v0.146.0
  1from __future__ import annotations
  2
  3import _thread
  4import warnings
  5from collections import deque
  6from collections.abc import Generator, Sequence
  7from contextlib import contextmanager
  8from typing import TYPE_CHECKING, Any, LiteralString, NamedTuple, cast
  9
 10import psycopg
 11from psycopg import errors
 12from psycopg import sql as psycopg_sql
 13
 14from plain.logs import get_framework_logger
 15from plain.postgres import utils
 16from plain.postgres.dialect import quote_name
 17from plain.postgres.fields import GenericIPAddressField, TimeField, UUIDField
 18from plain.postgres.schema import DatabaseSchemaEditor
 19from plain.postgres.sources import ConnectionSource
 20from plain.postgres.transaction import TransactionManagementError
 21from plain.postgres.utils import CursorDebugWrapper as BaseCursorDebugWrapper
 22from plain.postgres.utils import CursorWrapper, debug_transaction
 23from plain.runtime import settings
 24
 25if TYPE_CHECKING:
 26    from psycopg import Connection as PsycopgConnection
 27
 28    from plain.postgres.database_url import DatabaseConfig
 29    from plain.postgres.fields import Field
 30
 31logger = get_framework_logger()
 32
 33
 34def get_migratable_models() -> Generator[Any]:
 35    """Return all models that should be included in migrations."""
 36    from plain.packages import packages_registry
 37    from plain.postgres import models_registry
 38
 39    return (
 40        model
 41        for package_config in packages_registry.get_package_configs()
 42        for model in models_registry.get_models(
 43            package_label=package_config.package_label
 44        )
 45    )
 46
 47
 48class TableInfo(NamedTuple):
 49    """Structure returned by DatabaseConnection.get_table_list()."""
 50
 51    name: str
 52    type: str
 53    comment: str | None
 54
 55
 56class DatabaseConnection:
 57    """
 58    PostgreSQL database connection.
 59
 60    This is the only database backend supported by Plain.
 61    """
 62
 63    queries_limit: int = 9000
 64
 65    ignored_tables: list[str] = []
 66
 67    def __init__(self, source: ConnectionSource):
 68        # Lazy — acquired on first use via self._source.
 69        self.connection: PsycopgConnection[Any] | None = None
 70        self._source: ConnectionSource = source
 71        # Query logging in debug mode or when explicitly enabled.
 72        self.queries_log: deque[dict[str, Any]] = deque(maxlen=self.queries_limit)
 73        self.force_debug_cursor: bool = False
 74
 75        # Transaction related attributes.
 76        # Tracks if the connection is in autocommit mode. Per PEP 249, by
 77        # default, it isn't.
 78        self.autocommit: bool = False
 79        # Tracks if the connection is in a transaction managed by 'atomic'.
 80        self.in_atomic_block: bool = False
 81        # Increment to generate unique savepoint ids.
 82        self.savepoint_state: int = 0
 83        # List of savepoints created by 'atomic'.
 84        self.savepoint_ids: list[str | None] = []
 85        # Stack of active 'atomic' blocks.
 86        self.atomic_blocks: list[Any] = []
 87        # Tracks if the transaction should be rolled back to the next
 88        # available savepoint because of an exception in an inner block.
 89        self.needs_rollback: bool = False
 90        self.rollback_exc: Exception | None = None
 91
 92        # A list of no-argument functions to run when the transaction commits.
 93        # Each entry is an (sids, func, robust) tuple, where sids is a set of
 94        # the active savepoint IDs when this function was registered and robust
 95        # specifies whether it's allowed for the function to fail.
 96        self.run_on_commit: list[tuple[set[str | None], Any, bool]] = []
 97
 98        # Should we run the on-commit hooks the next time set_autocommit(True)
 99        # is called?
100        self.run_commit_hooks_on_set_autocommit_on: bool = False
101
102        # A stack of wrappers to be invoked around execute()/executemany()
103        # calls. Each entry is a function taking five arguments: execute, sql,
104        # params, many, and context. It's the function's responsibility to
105        # call execute(sql, params, many, context).
106        self.execute_wrappers: list[Any] = []
107
108    def __repr__(self) -> str:
109        return f"<{self.__class__.__qualname__} vendor='postgresql'>"
110
111    def __del__(self) -> None:
112        # Safety net for wrappers GC'd without an explicit close() —
113        # e.g. inside a short-lived `asyncio.to_thread` context copy.
114        # Returns the pooled connection to the pool. Guards handle
115        # interpreter shutdown, when attrs may already be cleared.
116        conn = getattr(self, "connection", None)
117        if conn is None:
118            return
119        source = getattr(self, "_source", None)
120        if source is None:
121            return
122        try:
123            source.release(conn)
124        except Exception:
125            pass
126
127    @property
128    def settings_dict(self) -> DatabaseConfig:
129        """Config of the server this wrapper talks to. For pool-backed
130        wrappers this always reflects the live `POSTGRES_URL`."""
131        return self._source.config
132
133    @property
134    def queries_logged(self) -> bool:
135        return self.force_debug_cursor or settings.DEBUG
136
137    @property
138    def queries(self) -> list[dict[str, Any]]:
139        if len(self.queries_log) == self.queries_log.maxlen:
140            warnings.warn(
141                f"Limit for query logging exceeded, only the last {self.queries_log.maxlen} queries "
142                "will be returned."
143            )
144        return list(self.queries_log)
145
146    # ##### Connection and cursor methods #####
147
148    def _set_autocommit(self, autocommit: bool) -> None:
149        """Backend-specific implementation to enable or disable autocommit."""
150        assert self.connection is not None
151        self.connection.autocommit = autocommit
152
153    def check_constraints(self, table_names: list[str] | None = None) -> None:
154        """
155        Check constraints by setting them to immediate. Return them to deferred
156        afterward.
157        """
158        with self.cursor() as cursor:
159            cursor.execute("SET CONSTRAINTS ALL IMMEDIATE")
160            cursor.execute("SET CONSTRAINTS ALL DEFERRED")
161
162    def make_debug_cursor(self, cursor: psycopg.Cursor[Any]) -> CursorDebugWrapper:
163        return CursorDebugWrapper(cursor, self)
164
165    # ##### Connection lifecycle #####
166
167    def connect(self) -> None:
168        """Connect to the database. Assume that the connection is closed."""
169        self.connection = self._source.acquire()
170        self.set_autocommit(True)
171
172    def ensure_connection(self) -> None:
173        """Guarantee that a connection to the database is established."""
174        if self.connection is None:
175            self.connect()
176
177    # ##### PEP-249 connection method wrappers #####
178
179    def _prepare_cursor(self, cursor: psycopg.Cursor[Any]) -> utils.CursorWrapper:
180        """
181        Validate the connection is usable and perform database cursor wrapping.
182        """
183        if self.queries_logged:
184            wrapped_cursor = self.make_debug_cursor(cursor)
185        else:
186            wrapped_cursor = self.make_cursor(cursor)
187        return wrapped_cursor
188
189    def _cursor(self) -> utils.CursorWrapper:
190        self.ensure_connection()
191        assert self.connection is not None
192        return self._prepare_cursor(self.connection.cursor())
193
194    def _commit(self) -> None:
195        if self.connection is not None:
196            with debug_transaction(self, "COMMIT"):
197                return self.connection.commit()
198
199    def _rollback(self) -> None:
200        if self.connection is not None:
201            with debug_transaction(self, "ROLLBACK"):
202                return self.connection.rollback()
203
204    # ##### Generic wrappers for PEP-249 connection methods #####
205
206    def cursor(self) -> utils.CursorWrapper:
207        """Create a cursor, opening a connection if necessary."""
208        return self._cursor()
209
210    def commit(self) -> None:
211        """Commit a transaction and reset the dirty flag."""
212        self.validate_no_atomic_block()
213        self._commit()
214        self.run_commit_hooks_on_set_autocommit_on = True
215
216    def rollback(self) -> None:
217        """Roll back a transaction and reset the dirty flag."""
218        self.validate_no_atomic_block()
219        self._rollback()
220        self.needs_rollback = False
221        self.run_on_commit = []
222
223    def close(self) -> None:
224        """Close the connection to the database."""
225        # Closing mid-atomic would reopen a fresh autocommit connection on
226        # the next cursor() and silently run the rest of the block outside
227        # its transaction. Callers that drop a connection during error
228        # recovery (see Atomic.__exit__) unwind the atomic state first.
229        self.validate_no_atomic_block()
230
231        self.run_on_commit = []
232        if self.connection is None:
233            return
234        try:
235            self._source.release(self.connection)
236        finally:
237            # Null the reference so __del__ (and ensure_connection) can't
238            # touch an already-released psycopg connection.
239            self.connection = None
240
241    # ##### Savepoint management #####
242
243    def _savepoint(self, sid: str) -> None:
244        with self.cursor() as cursor:
245            cursor.execute(f"SAVEPOINT {quote_name(sid)}")
246
247    def _savepoint_rollback(self, sid: str) -> None:
248        with self.cursor() as cursor:
249            cursor.execute(f"ROLLBACK TO SAVEPOINT {quote_name(sid)}")
250
251    def _savepoint_commit(self, sid: str) -> None:
252        with self.cursor() as cursor:
253            cursor.execute(f"RELEASE SAVEPOINT {quote_name(sid)}")
254
255    # ##### Generic savepoint management methods #####
256
257    def savepoint(self) -> str | None:
258        """
259        Create a savepoint inside the current transaction. Return an
260        identifier for the savepoint that will be used for the subsequent
261        rollback or commit. Return None if in autocommit mode (no transaction).
262        """
263        if self.get_autocommit():
264            return None
265
266        thread_ident = _thread.get_ident()
267        tid = str(thread_ident).replace("-", "")
268
269        self.savepoint_state += 1
270        sid = "s%s_x%d" % (tid, self.savepoint_state)  # noqa: UP031
271
272        self._savepoint(sid)
273
274        return sid
275
276    def savepoint_rollback(self, sid: str) -> None:
277        """
278        Roll back to a savepoint. Do nothing if in autocommit mode.
279        """
280        if self.get_autocommit():
281            return
282
283        self._savepoint_rollback(sid)
284
285        # Remove any callbacks registered while this savepoint was active.
286        self.run_on_commit = [
287            (sids, func, robust)
288            for (sids, func, robust) in self.run_on_commit
289            if sid not in sids
290        ]
291
292    def savepoint_commit(self, sid: str) -> None:
293        """
294        Release a savepoint. Do nothing if in autocommit mode.
295        """
296        if self.get_autocommit():
297            return
298
299        self._savepoint_commit(sid)
300
301    def clean_savepoints(self) -> None:
302        """
303        Reset the counter used to generate unique savepoint ids in this thread.
304        """
305        self.savepoint_state = 0
306
307    # ##### Generic transaction management methods #####
308
309    def get_autocommit(self) -> bool:
310        """Get the autocommit state."""
311        self.ensure_connection()
312        return self.autocommit
313
314    def set_autocommit(self, autocommit: bool) -> None:
315        """
316        Enable or disable autocommit.
317
318        Used internally by atomic() to manage transactions. Don't call this
319        directly — use atomic() instead.
320        """
321        self.validate_no_atomic_block()
322        self.ensure_connection()
323
324        if autocommit:
325            self._set_autocommit(autocommit)
326        else:
327            with debug_transaction(self, "BEGIN"):
328                self._set_autocommit(autocommit)
329        self.autocommit = autocommit
330
331        if autocommit and self.run_commit_hooks_on_set_autocommit_on:
332            self.run_and_clear_commit_hooks()
333            self.run_commit_hooks_on_set_autocommit_on = False
334
335    def get_rollback(self) -> bool:
336        """Get the "needs rollback" flag -- for *advanced use* only."""
337        if not self.in_atomic_block:
338            raise TransactionManagementError(
339                "The rollback flag doesn't work outside of an 'atomic' block."
340            )
341        return self.needs_rollback
342
343    def set_rollback(self, rollback: bool) -> None:
344        """
345        Set or unset the "needs rollback" flag -- for *advanced use* only.
346        """
347        if not self.in_atomic_block:
348            raise TransactionManagementError(
349                "The rollback flag doesn't work outside of an 'atomic' block."
350            )
351        self.needs_rollback = rollback
352
353    def validate_no_atomic_block(self) -> None:
354        """Raise an error if an atomic block is active."""
355        if self.in_atomic_block:
356            raise TransactionManagementError(
357                "This is forbidden when an 'atomic' block is active."
358            )
359
360    def validate_no_broken_transaction(self) -> None:
361        if self.needs_rollback:
362            raise TransactionManagementError(
363                "An error occurred in the current transaction. You can't "
364                "execute queries until the end of the 'atomic' block."
365            ) from self.rollback_exc
366
367    # ##### Miscellaneous #####
368
369    def make_cursor(self, cursor: psycopg.Cursor[Any]) -> utils.CursorWrapper:
370        """Create a cursor without debug logging."""
371        return utils.CursorWrapper(cursor, self)
372
373    def schema_editor(self, *args: Any, **kwargs: Any) -> DatabaseSchemaEditor:
374        """Return a new instance of the schema editor."""
375        return DatabaseSchemaEditor(self, *args, **kwargs)
376
377    def on_commit(self, func: Any, robust: bool = False) -> None:
378        if not callable(func):
379            raise TypeError("on_commit()'s callback must be a callable.")
380        if self.in_atomic_block:
381            # Transaction in progress; save for execution on commit.
382            self.run_on_commit.append((set(self.savepoint_ids), func, robust))
383        else:
384            # No transaction in progress; execute immediately.
385            if robust:
386                try:
387                    func()
388                except Exception as e:
389                    logger.error(
390                        "Error calling on_commit() handler",
391                        exc_info=True,
392                        extra={"handler": func.__qualname__, "error": str(e)},
393                    )
394            else:
395                func()
396
397    def run_and_clear_commit_hooks(self) -> None:
398        self.validate_no_atomic_block()
399        current_run_on_commit = self.run_on_commit
400        self.run_on_commit = []
401        while current_run_on_commit:
402            _, func, robust = current_run_on_commit.pop(0)
403            if robust:
404                try:
405                    func()
406                except Exception as e:
407                    logger.error(
408                        "Error calling on_commit() handler during transaction",
409                        exc_info=True,
410                        extra={"handler": func.__qualname__, "error": str(e)},
411                    )
412            else:
413                func()
414
415    @contextmanager
416    def execute_wrapper(self, wrapper: Any) -> Generator[None]:
417        """
418        Return a context manager under which the wrapper is applied to suitable
419        database query executions.
420        """
421        self.execute_wrappers.append(wrapper)
422        try:
423            yield
424        finally:
425            self.execute_wrappers.pop()
426
427    # ##### SQL generation methods that require connection state #####
428
429    def compose_sql(self, query: str, params: Any) -> str:
430        """
431        Compose a SQL query with parameters using psycopg's mogrify.
432
433        This requires an active connection because it uses the connection's
434        cursor to properly format parameters.
435        """
436        assert self.connection is not None
437        return psycopg.ClientCursor(self.connection).mogrify(
438            psycopg_sql.SQL(cast(LiteralString, query)), params
439        )
440
441    def last_executed_query(
442        self,
443        cursor: utils.CursorWrapper,
444        sql: str,
445        params: Any,
446    ) -> str | None:
447        """
448        Return a string of the query last executed by the given cursor, with
449        placeholders replaced with actual values.
450        """
451        try:
452            return self.compose_sql(sql, params)
453        except errors.DataError:
454            return None
455
456    def unification_cast_sql(self, output_field: Field) -> str:
457        """
458        Given a field instance, return the SQL that casts the result of a union
459        to that type. The resulting string should contain a '%s' placeholder
460        for the expression being cast.
461        """
462        if isinstance(output_field, GenericIPAddressField | TimeField | UUIDField):
463            # PostgreSQL will resolve a union as type 'text' if input types are
464            # 'unknown'.
465            # https://www.postgresql.org/docs/current/typeconv-union-case.html
466            # These fields cannot be implicitly cast back in the default
467            # PostgreSQL configuration so we need to explicitly cast them.
468            # We must also remove components of the type within brackets:
469            # varchar(255) -> varchar.
470            db_type = output_field.db_type()
471            if db_type:
472                return "CAST(%s AS {})".format(db_type.split("(")[0])
473        return "%s"
474
475    # ##### Introspection methods #####
476
477    def table_names(
478        self, cursor: CursorWrapper | None = None, include_views: bool = False
479    ) -> list[str]:
480        """
481        Return a list of names of all tables that exist in the database.
482        Sort the returned table list by Python's default sorting. Do NOT use
483        the database's ORDER BY here to avoid subtle differences in sorting
484        order between databases.
485        """
486
487        def get_names(cursor: CursorWrapper) -> list[str]:
488            return sorted(
489                ti.name
490                for ti in self.get_table_list(cursor)
491                if include_views or ti.type == "t"
492            )
493
494        if cursor is None:
495            with self.cursor() as cursor:
496                return get_names(cursor)
497        return get_names(cursor)
498
499    def get_table_list(self, cursor: CursorWrapper) -> Sequence[TableInfo]:
500        """
501        Return an unsorted list of TableInfo named tuples of all tables and
502        views that exist in the database.
503        """
504        cursor.execute(
505            """
506            SELECT
507                c.relname,
508                CASE
509                    WHEN c.relispartition THEN 'p'
510                    WHEN c.relkind IN ('m', 'v') THEN 'v'
511                    ELSE 't'
512                END,
513                obj_description(c.oid, 'pg_class')
514            FROM pg_catalog.pg_class c
515            LEFT JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace
516            WHERE c.relkind IN ('f', 'm', 'p', 'r', 'v')
517                AND n.nspname NOT IN ('pg_catalog', 'pg_toast')
518                AND pg_catalog.pg_table_is_visible(c.oid)
519        """
520        )
521        return [
522            TableInfo(*row)
523            for row in cursor.fetchall()
524            if row[0] not in self.ignored_tables
525        ]
526
527    def plain_table_names(
528        self, only_existing: bool = False, include_views: bool = True
529    ) -> list[str]:
530        """
531        Return a list of all table names that have associated Plain models and
532        are in INSTALLED_PACKAGES.
533
534        If only_existing is True, include only the tables in the database.
535        """
536        tables = set()
537        for model in get_migratable_models():
538            tables.add(model.model_options.db_table)
539            tables.update(
540                f.m2m_db_table() for f in model._model_meta.local_many_to_many
541            )
542        tables = list(tables)
543        if only_existing:
544            existing_tables = set(self.table_names(include_views=include_views))
545            tables = [t for t in tables if t in existing_tables]
546        return tables
547
548    def get_sequences(
549        self, cursor: CursorWrapper, table_name: str, table_fields: tuple[Any, ...] = ()
550    ) -> list[dict[str, Any]]:
551        """
552        Return a list of introspected sequences for table_name. Each sequence
553        is a dict: {'table': <table_name>, 'column': <column_name>, 'name': <sequence_name>}.
554        """
555        cursor.execute(
556            """
557            SELECT
558                s.relname AS sequence_name,
559                a.attname AS colname
560            FROM
561                pg_class s
562                JOIN pg_depend d ON d.objid = s.oid
563                    AND d.classid = 'pg_class'::regclass
564                    AND d.refclassid = 'pg_class'::regclass
565                JOIN pg_attribute a ON d.refobjid = a.attrelid
566                    AND d.refobjsubid = a.attnum
567                JOIN pg_class tbl ON tbl.oid = d.refobjid
568                    AND tbl.relname = %s
569                    AND pg_catalog.pg_table_is_visible(tbl.oid)
570            WHERE
571                s.relkind = 'S';
572        """,
573            [table_name],
574        )
575        return [
576            {"name": row[0], "table": table_name, "column": row[1]}
577            for row in cursor.fetchall()
578        ]
579
580    def get_constraints(
581        self, cursor: CursorWrapper, table_name: str
582    ) -> dict[str, dict[str, Any]]:
583        """
584        Retrieve any constraints or keys (unique, pk, fk, check, index) across
585        one or more columns. Also retrieve the definition of expression-based
586        indexes.
587        """
588        constraints: dict[str, dict[str, Any]] = {}
589        # Loop over the key table, collecting things as constraints. The column
590        # array must return column names in the same order in which they were
591        # created.
592        cursor.execute(
593            """
594            SELECT
595                c.conname,
596                array(
597                    SELECT attname
598                    FROM unnest(c.conkey) WITH ORDINALITY cols(colid, arridx)
599                    JOIN pg_attribute AS ca ON cols.colid = ca.attnum
600                    WHERE ca.attrelid = c.conrelid
601                    ORDER BY cols.arridx
602                ),
603                c.contype,
604                (SELECT fkc.relname || '.' || fka.attname
605                FROM pg_attribute AS fka
606                JOIN pg_class AS fkc ON fka.attrelid = fkc.oid
607                WHERE fka.attrelid = c.confrelid AND fka.attnum = c.confkey[1]),
608                c.convalidated,
609                pg_get_constraintdef(c.oid),
610                c.confdeltype
611            FROM pg_constraint AS c
612            JOIN pg_class AS cl ON c.conrelid = cl.oid
613            WHERE cl.relname = %s AND pg_catalog.pg_table_is_visible(cl.oid)
614        """,
615            [table_name],
616        )
617        for (
618            constraint,
619            columns,
620            kind,
621            used_cols,
622            validated,
623            constraintdef,
624            confdeltype,
625        ) in cursor.fetchall():
626            constraints[constraint] = {
627                "columns": columns,
628                "foreign_key": tuple(used_cols.split(".", 1)) if kind == "f" else None,
629                "contype": kind,
630                "index": False,
631                "definition": constraintdef,
632                "validated": validated,
633                "on_delete_action": confdeltype if kind == "f" else None,
634            }
635        # Now get indexes. Sort order, opclasses, INCLUDE, and predicates all
636        # ride along inside `pg_get_indexdef` and are compared via the
637        # canonical-tail round-trip in convergence — no need to introspect
638        # them here as separate columns.
639        cursor.execute(
640            """
641            SELECT
642                indexname,
643                array_agg(attname ORDER BY arridx),
644                indisunique,
645                amname,
646                exprdef,
647                indisvalid
648            FROM (
649                SELECT
650                    c2.relname as indexname, idx.*, attr.attname, am.amname,
651                    pg_get_indexdef(idx.indexrelid) AS exprdef
652                FROM (
653                    SELECT *
654                    FROM
655                        pg_index i,
656                        unnest(i.indkey)
657                            WITH ORDINALITY koi(key, arridx)
658                ) idx
659                LEFT JOIN pg_class c ON idx.indrelid = c.oid
660                LEFT JOIN pg_class c2 ON idx.indexrelid = c2.oid
661                LEFT JOIN pg_am am ON c2.relam = am.oid
662                LEFT JOIN
663                    pg_attribute attr ON attr.attrelid = c.oid AND attr.attnum = idx.key
664                WHERE c.relname = %s AND pg_catalog.pg_table_is_visible(c.oid)
665            ) s2
666            GROUP BY
667                indexname, indisunique, amname, exprdef, indisvalid;
668        """,
669            [table_name],
670        )
671        for (
672            index,
673            columns,
674            unique,
675            type_,
676            definition,
677            valid,
678        ) in cursor.fetchall():
679            if index not in constraints:
680                constraints[index] = {
681                    "columns": columns if columns != [None] else [],
682                    "unique": unique,
683                    "index": True,
684                    "type": type_,
685                    "definition": definition,
686                    "valid": valid,
687                }
688        return constraints
689
690
class CursorDebugWrapper(BaseCursorDebugWrapper):
    """Debug cursor wrapper that also covers psycopg COPY operations."""

    @contextmanager
    def copy(self, statement: Any) -> Generator[Any]:
        """
        Run psycopg's `Cursor.copy(statement)` inside `debug_sql(statement)`.

        psycopg's `copy()` is a context manager: the COPY command is sent when
        it is entered and data moves inside the caller's `with` block.
        Wrapping only the call itself would close the debug scope before any
        work happened, so the scope is kept open across the yield. Use it
        exactly like psycopg's method: `with cursor.copy(sql) as copy: ...`.
        """
        with self.debug_sql(statement):
            with self.cursor.copy(statement) as copy:
                yield copy