Plain is headed towards 1.0! Subscribe for development updates →

  1from __future__ import annotations
  2
  3import _thread
  4import copy
  5import datetime
  6import logging
  7import threading
  8import time
  9import warnings
 10import zoneinfo
 11from collections import deque
 12from collections.abc import Generator
 13from contextlib import contextmanager
 14from functools import cached_property
 15from typing import TYPE_CHECKING, Any
 16
 17from plain.models.backends import utils
 18from plain.models.backends.base.validation import BaseDatabaseValidation
 19from plain.models.backends.utils import debug_transaction
 20from plain.models.db import (
 21    DatabaseError,
 22    DatabaseErrorWrapper,
 23    NotSupportedError,
 24)
 25from plain.models.transaction import TransactionManagementError
 26from plain.runtime import settings
 27
 28if TYPE_CHECKING:
 29    from plain.models.backends.base.client import BaseDatabaseClient
 30    from plain.models.backends.base.creation import BaseDatabaseCreation
 31    from plain.models.backends.base.features import BaseDatabaseFeatures
 32    from plain.models.backends.base.introspection import BaseDatabaseIntrospection
 33    from plain.models.backends.base.operations import BaseDatabaseOperations
 34    from plain.models.backends.base.schema import BaseDatabaseSchemaEditor
 35
# Module-level flag shared by every wrapper instance: init_connection_state()
# runs check_database_version_supported() only while this is False and flips
# it to True once that check returns without raising, so the version check
# is performed at most once after it first succeeds.
RAN_DB_VERSION_CHECK = False

# Logger used below to report exceptions raised by robust on_commit() callbacks.
logger = logging.getLogger("plain.models.backends.base")
 39
 40
 41class BaseDatabaseWrapper:
 42    """Represent a database connection."""
 43
 44    # Mapping of Field objects to their column types.
 45    data_types: dict[str, str] = {}
 46    # Mapping of Field objects to their SQL suffix such as AUTOINCREMENT.
 47    data_types_suffix: dict[str, str] = {}
 48    # Mapping of Field objects to their SQL for CHECK constraints.
 49    data_type_check_constraints: dict[str, str] = {}
 50    # Instance attributes - always assigned in __init__
 51    ops: BaseDatabaseOperations
 52    client: BaseDatabaseClient
 53    creation: BaseDatabaseCreation
 54    features: BaseDatabaseFeatures
 55    introspection: BaseDatabaseIntrospection
 56    validation: BaseDatabaseValidation
 57    vendor: str = "unknown"
 58    display_name: str = "unknown"
 59    SchemaEditorClass: type[BaseDatabaseSchemaEditor] | None = None
 60    # Classes instantiated in __init__().
 61    client_class: type[BaseDatabaseClient] | None = None
 62    creation_class: type[BaseDatabaseCreation] | None = None
 63    features_class: type[BaseDatabaseFeatures] | None = None
 64    introspection_class: type[BaseDatabaseIntrospection] | None = None
 65    ops_class: type[BaseDatabaseOperations] | None = None
 66    validation_class: type[BaseDatabaseValidation] = BaseDatabaseValidation
 67
 68    queries_limit: int = 9000
 69
 70    def __init__(self, settings_dict: dict[str, Any]):
 71        # Connection related attributes.
 72        # The underlying database connection.
 73        self.connection: BaseDatabaseWrapper | None = None
 74        # `settings_dict` should be a dictionary containing keys such as
 75        # NAME, USER, etc. It's called `settings_dict` instead of `settings`
 76        # to disambiguate it from Plain settings modules.
 77        self.settings_dict: dict[str, Any] = settings_dict
 78        # Query logging in debug mode or when explicitly enabled.
 79        self.queries_log: deque[dict[str, Any]] = deque(maxlen=self.queries_limit)
 80        self.force_debug_cursor: bool = False
 81
 82        # Transaction related attributes.
 83        # Tracks if the connection is in autocommit mode. Per PEP 249, by
 84        # default, it isn't.
 85        self.autocommit: bool = False
 86        # Tracks if the connection is in a transaction managed by 'atomic'.
 87        self.in_atomic_block: bool = False
 88        # Increment to generate unique savepoint ids.
 89        self.savepoint_state: int = 0
 90        # List of savepoints created by 'atomic'.
 91        self.savepoint_ids: list[str] = []
 92        # Stack of active 'atomic' blocks.
 93        self.atomic_blocks: list[Any] = []
 94        # Tracks if the outermost 'atomic' block should commit on exit,
 95        # ie. if autocommit was active on entry.
 96        self.commit_on_exit: bool = True
 97        # Tracks if the transaction should be rolled back to the next
 98        # available savepoint because of an exception in an inner block.
 99        self.needs_rollback: bool = False
100        self.rollback_exc: Exception | None = None
101
102        # Connection termination related attributes.
103        self.close_at: float | None = None
104        self.closed_in_transaction: bool = False
105        self.errors_occurred: bool = False
106        self.health_check_enabled: bool = False
107        self.health_check_done: bool = False
108
109        # Thread-safety related attributes.
110        self._thread_sharing_lock: threading.Lock = threading.Lock()
111        self._thread_sharing_count: int = 0
112        self._thread_ident: int = _thread.get_ident()
113
114        # A list of no-argument functions to run when the transaction commits.
115        # Each entry is an (sids, func, robust) tuple, where sids is a set of
116        # the active savepoint IDs when this function was registered and robust
117        # specifies whether it's allowed for the function to fail.
118        self.run_on_commit: list[tuple[set[str], Any, bool]] = []
119
120        # Should we run the on-commit hooks the next time set_autocommit(True)
121        # is called?
122        self.run_commit_hooks_on_set_autocommit_on: bool = False
123
124        # A stack of wrappers to be invoked around execute()/executemany()
125        # calls. Each entry is a function taking five arguments: execute, sql,
126        # params, many, and context. It's the function's responsibility to
127        # call execute(sql, params, many, context).
128        self.execute_wrappers: list[Any] = []
129
130        self.client: BaseDatabaseClient = self.client_class(self)  # type: ignore[misc]
131        self.creation: BaseDatabaseCreation = self.creation_class(self)  # type: ignore[misc]
132        self.features: BaseDatabaseFeatures = self.features_class(self)  # type: ignore[misc]
133        self.introspection: BaseDatabaseIntrospection = self.introspection_class(self)  # type: ignore[misc]
134        self.ops: BaseDatabaseOperations = self.ops_class(self)  # type: ignore[misc]
135        self.validation: BaseDatabaseValidation = self.validation_class(self)
136
137    def __repr__(self) -> str:
138        return f"<{self.__class__.__qualname__} vendor={self.vendor!r}>"
139
140    def ensure_timezone(self) -> bool:
141        """
142        Ensure the connection's timezone is set to `self.timezone_name` and
143        return whether it changed or not.
144        """
145        return False
146
147    @cached_property
148    def timezone(self) -> datetime.tzinfo:
149        """
150        Return a tzinfo of the database connection time zone.
151
152        This is only used when time zone support is enabled. When a datetime is
153        read from the database, it is always returned in this time zone.
154
155        When the database backend supports time zones, it doesn't matter which
156        time zone Plain uses, as long as aware datetimes are used everywhere.
157        Other users connecting to the database can choose their own time zone.
158
159        When the database backend doesn't support time zones, the time zone
160        Plain uses may be constrained by the requirements of other users of
161        the database.
162        """
163        if self.settings_dict["TIME_ZONE"] is None:
164            return datetime.UTC
165        else:
166            return zoneinfo.ZoneInfo(self.settings_dict["TIME_ZONE"])
167
168    @cached_property
169    def timezone_name(self) -> str:
170        """
171        Name of the time zone of the database connection.
172        """
173        if self.settings_dict["TIME_ZONE"] is None:
174            return "UTC"
175        else:
176            return self.settings_dict["TIME_ZONE"]
177
178    @property
179    def queries_logged(self) -> bool:
180        return self.force_debug_cursor or settings.DEBUG
181
182    @property
183    def queries(self) -> list[dict[str, Any]]:
184        if len(self.queries_log) == self.queries_log.maxlen:
185            warnings.warn(
186                f"Limit for query logging exceeded, only the last {self.queries_log.maxlen} queries "
187                "will be returned."
188            )
189        return list(self.queries_log)
190
191    def get_database_version(self) -> tuple[int, ...]:
192        """Return a tuple of the database's version."""
193        raise NotImplementedError(
194            "subclasses of BaseDatabaseWrapper may require a get_database_version() "
195            "method."
196        )
197
198    def check_database_version_supported(self) -> None:
199        """
200        Raise an error if the database version isn't supported by this
201        version of Plain.
202        """
203        if (
204            self.features.minimum_database_version is not None
205            and self.get_database_version() < self.features.minimum_database_version
206        ):
207            db_version = ".".join(map(str, self.get_database_version()))
208            min_db_version = ".".join(map(str, self.features.minimum_database_version))
209            raise NotSupportedError(
210                f"{self.display_name} {min_db_version} or later is required "
211                f"(found {db_version})."
212            )
213
214    # ##### Backend-specific methods for creating connections and cursors #####
215
216    def get_connection_params(self) -> dict[str, Any]:
217        """Return a dict of parameters suitable for get_new_connection."""
218        raise NotImplementedError(
219            "subclasses of BaseDatabaseWrapper may require a get_connection_params() "
220            "method"
221        )
222
223    def get_new_connection(self, conn_params: dict[str, Any]) -> Any:
224        """Open a connection to the database."""
225        raise NotImplementedError(
226            "subclasses of BaseDatabaseWrapper may require a get_new_connection() "
227            "method"
228        )
229
230    def init_connection_state(self) -> None:
231        """Initialize the database connection settings."""
232        global RAN_DB_VERSION_CHECK
233        if not RAN_DB_VERSION_CHECK:
234            self.check_database_version_supported()
235            RAN_DB_VERSION_CHECK = True
236
237    def create_cursor(self, name: str | None = None) -> Any:
238        """Create a cursor. Assume that a connection is established."""
239        raise NotImplementedError(
240            "subclasses of BaseDatabaseWrapper may require a create_cursor() method"
241        )
242
243    # ##### Backend-specific methods for creating connections #####
244
245    def connect(self) -> None:
246        """Connect to the database. Assume that the connection is closed."""
247        # In case the previous connection was closed while in an atomic block
248        self.in_atomic_block = False
249        self.savepoint_ids = []
250        self.atomic_blocks = []
251        self.needs_rollback = False
252        # Reset parameters defining when to close/health-check the connection.
253        self.health_check_enabled = self.settings_dict["CONN_HEALTH_CHECKS"]
254        max_age = self.settings_dict["CONN_MAX_AGE"]
255        self.close_at = None if max_age is None else time.monotonic() + max_age
256        self.closed_in_transaction = False
257        self.errors_occurred = False
258        # New connections are healthy.
259        self.health_check_done = True
260        # Establish the connection
261        conn_params = self.get_connection_params()
262        self.connection = self.get_new_connection(conn_params)
263        self.set_autocommit(self.settings_dict["AUTOCOMMIT"])
264        self.init_connection_state()
265
266        self.run_on_commit = []
267
268    def ensure_connection(self) -> None:
269        """Guarantee that a connection to the database is established."""
270        if self.connection is None:
271            with self.wrap_database_errors:
272                self.connect()
273
274    # ##### Backend-specific wrappers for PEP-249 connection methods #####
275
276    def _prepare_cursor(self, cursor: Any) -> utils.CursorWrapper:
277        """
278        Validate the connection is usable and perform database cursor wrapping.
279        """
280        self.validate_thread_sharing()
281        if self.queries_logged:
282            wrapped_cursor = self.make_debug_cursor(cursor)
283        else:
284            wrapped_cursor = self.make_cursor(cursor)
285        return wrapped_cursor
286
287    def _cursor(self, name: str | None = None) -> utils.CursorWrapper:
288        self.close_if_health_check_failed()
289        self.ensure_connection()
290        with self.wrap_database_errors:
291            return self._prepare_cursor(self.create_cursor(name))
292
293    def _commit(self) -> None:
294        if self.connection is not None:
295            with debug_transaction(self, "COMMIT"), self.wrap_database_errors:
296                return self.connection.commit()
297
298    def _rollback(self) -> None:
299        if self.connection is not None:
300            with debug_transaction(self, "ROLLBACK"), self.wrap_database_errors:
301                return self.connection.rollback()
302
303    def _close(self) -> None:
304        if self.connection is not None:
305            with self.wrap_database_errors:
306                return self.connection.close()
307
308    # ##### Generic wrappers for PEP-249 connection methods #####
309
310    def cursor(self) -> utils.CursorWrapper:
311        """Create a cursor, opening a connection if necessary."""
312        return self._cursor()
313
314    def commit(self) -> None:
315        """Commit a transaction and reset the dirty flag."""
316        self.validate_thread_sharing()
317        self.validate_no_atomic_block()
318        self._commit()
319        # A successful commit means that the database connection works.
320        self.errors_occurred = False
321        self.run_commit_hooks_on_set_autocommit_on = True
322
323    def rollback(self) -> None:
324        """Roll back a transaction and reset the dirty flag."""
325        self.validate_thread_sharing()
326        self.validate_no_atomic_block()
327        self._rollback()
328        # A successful rollback means that the database connection works.
329        self.errors_occurred = False
330        self.needs_rollback = False
331        self.run_on_commit = []
332
333    def close(self) -> None:
334        """Close the connection to the database."""
335        self.validate_thread_sharing()
336        self.run_on_commit = []
337
338        # Don't call validate_no_atomic_block() to avoid making it difficult
339        # to get rid of a connection in an invalid state. The next connect()
340        # will reset the transaction state anyway.
341        if self.closed_in_transaction or self.connection is None:
342            return
343        try:
344            self._close()
345        finally:
346            if self.in_atomic_block:
347                self.closed_in_transaction = True
348                self.needs_rollback = True
349            else:
350                self.connection = None
351
352    # ##### Backend-specific savepoint management methods #####
353
354    def _savepoint(self, sid: str) -> None:
355        with self.cursor() as cursor:
356            cursor.execute(self.ops.savepoint_create_sql(sid))
357
358    def _savepoint_rollback(self, sid: str) -> None:
359        with self.cursor() as cursor:
360            cursor.execute(self.ops.savepoint_rollback_sql(sid))
361
362    def _savepoint_commit(self, sid: str) -> None:
363        with self.cursor() as cursor:
364            cursor.execute(self.ops.savepoint_commit_sql(sid))
365
366    def _savepoint_allowed(self) -> bool:
367        # Savepoints cannot be created outside a transaction
368        return self.features.uses_savepoints and not self.get_autocommit()
369
370    # ##### Generic savepoint management methods #####
371
372    def savepoint(self) -> str | None:
373        """
374        Create a savepoint inside the current transaction. Return an
375        identifier for the savepoint that will be used for the subsequent
376        rollback or commit. Do nothing if savepoints are not supported.
377        """
378        if not self._savepoint_allowed():
379            return None
380
381        thread_ident = _thread.get_ident()
382        tid = str(thread_ident).replace("-", "")
383
384        self.savepoint_state += 1
385        sid = "s%s_x%d" % (tid, self.savepoint_state)  # noqa: UP031
386
387        self.validate_thread_sharing()
388        self._savepoint(sid)
389
390        return sid
391
392    def savepoint_rollback(self, sid: str) -> None:
393        """
394        Roll back to a savepoint. Do nothing if savepoints are not supported.
395        """
396        if not self._savepoint_allowed():
397            return
398
399        self.validate_thread_sharing()
400        self._savepoint_rollback(sid)
401
402        # Remove any callbacks registered while this savepoint was active.
403        self.run_on_commit = [
404            (sids, func, robust)
405            for (sids, func, robust) in self.run_on_commit
406            if sid not in sids
407        ]
408
409    def savepoint_commit(self, sid: str) -> None:
410        """
411        Release a savepoint. Do nothing if savepoints are not supported.
412        """
413        if not self._savepoint_allowed():
414            return
415
416        self.validate_thread_sharing()
417        self._savepoint_commit(sid)
418
419    def clean_savepoints(self) -> None:
420        """
421        Reset the counter used to generate unique savepoint ids in this thread.
422        """
423        self.savepoint_state = 0
424
425    # ##### Backend-specific transaction management methods #####
426
427    def _set_autocommit(self, autocommit: bool) -> None:
428        """
429        Backend-specific implementation to enable or disable autocommit.
430        """
431        raise NotImplementedError(
432            "subclasses of BaseDatabaseWrapper may require a _set_autocommit() method"
433        )
434
435    # ##### Generic transaction management methods #####
436
437    def get_autocommit(self) -> bool:
438        """Get the autocommit state."""
439        self.ensure_connection()
440        return self.autocommit
441
442    def set_autocommit(
443        self,
444        autocommit: bool,
445        force_begin_transaction_with_broken_autocommit: bool = False,
446    ) -> None:
447        """
448        Enable or disable autocommit.
449
450        The usual way to start a transaction is to turn autocommit off.
451        SQLite does not properly start a transaction when disabling
452        autocommit. To avoid this buggy behavior and to actually enter a new
453        transaction, an explicit BEGIN is required. Using
454        force_begin_transaction_with_broken_autocommit=True will issue an
455        explicit BEGIN with SQLite. This option will be ignored for other
456        backends.
457        """
458        self.validate_no_atomic_block()
459        self.close_if_health_check_failed()
460        self.ensure_connection()
461
462        start_transaction_under_autocommit = (
463            force_begin_transaction_with_broken_autocommit
464            and not autocommit
465            and hasattr(self, "_start_transaction_under_autocommit")
466        )
467
468        if start_transaction_under_autocommit:
469            self._start_transaction_under_autocommit()  # type: ignore[attr-defined]
470        elif autocommit:
471            self._set_autocommit(autocommit)
472        else:
473            with debug_transaction(self, "BEGIN"):
474                self._set_autocommit(autocommit)
475        self.autocommit = autocommit
476
477        if autocommit and self.run_commit_hooks_on_set_autocommit_on:
478            self.run_and_clear_commit_hooks()
479            self.run_commit_hooks_on_set_autocommit_on = False
480
481    def get_rollback(self) -> bool:
482        """Get the "needs rollback" flag -- for *advanced use* only."""
483        if not self.in_atomic_block:
484            raise TransactionManagementError(
485                "The rollback flag doesn't work outside of an 'atomic' block."
486            )
487        return self.needs_rollback
488
489    def set_rollback(self, rollback: bool) -> None:
490        """
491        Set or unset the "needs rollback" flag -- for *advanced use* only.
492        """
493        if not self.in_atomic_block:
494            raise TransactionManagementError(
495                "The rollback flag doesn't work outside of an 'atomic' block."
496            )
497        self.needs_rollback = rollback
498
499    def validate_no_atomic_block(self) -> None:
500        """Raise an error if an atomic block is active."""
501        if self.in_atomic_block:
502            raise TransactionManagementError(
503                "This is forbidden when an 'atomic' block is active."
504            )
505
506    def validate_no_broken_transaction(self) -> None:
507        if self.needs_rollback:
508            raise TransactionManagementError(
509                "An error occurred in the current transaction. You can't "
510                "execute queries until the end of the 'atomic' block."
511            ) from self.rollback_exc
512
513    # ##### Foreign key constraints checks handling #####
514
515    def disable_constraint_checking(self) -> bool:
516        """
517        Backends can implement as needed to temporarily disable foreign key
518        constraint checking. Should return True if the constraints were
519        disabled and will need to be reenabled.
520        """
521        return False
522
523    def enable_constraint_checking(self) -> None:
524        """
525        Backends can implement as needed to re-enable foreign key constraint
526        checking.
527        """
528        pass
529
530    def check_constraints(self, table_names: list[str] | None = None) -> None:
531        """
532        Backends can override this method if they can apply constraint
533        checking (e.g. via "SET CONSTRAINTS ALL IMMEDIATE"). Should raise an
534        IntegrityError if any invalid foreign key references are encountered.
535        """
536        pass
537
538    # ##### Connection termination handling #####
539
540    def is_usable(self) -> bool:
541        """
542        Test if the database connection is usable.
543
544        This method may assume that self.connection is not None.
545
546        Actual implementations should take care not to raise exceptions
547        as that may prevent Plain from recycling unusable connections.
548        """
549        raise NotImplementedError(
550            "subclasses of BaseDatabaseWrapper may require an is_usable() method"
551        )
552
553    def close_if_health_check_failed(self) -> None:
554        """Close existing connection if it fails a health check."""
555        if (
556            self.connection is None
557            or not self.health_check_enabled
558            or self.health_check_done
559        ):
560            return
561
562        if not self.is_usable():
563            self.close()
564        self.health_check_done = True
565
566    def close_if_unusable_or_obsolete(self) -> None:
567        """
568        Close the current connection if unrecoverable errors have occurred
569        or if it outlived its maximum age.
570        """
571        if self.connection is not None:
572            self.health_check_done = False
573            # If the application didn't restore the original autocommit setting,
574            # don't take chances, drop the connection.
575            if self.get_autocommit() != self.settings_dict["AUTOCOMMIT"]:
576                self.close()
577                return
578
579            # If an exception other than DataError or IntegrityError occurred
580            # since the last commit / rollback, check if the connection works.
581            if self.errors_occurred:
582                if self.is_usable():
583                    self.errors_occurred = False
584                    self.health_check_done = True
585                else:
586                    self.close()
587                    return
588
589            if self.close_at is not None and time.monotonic() >= self.close_at:
590                self.close()
591                return
592
593    # ##### Thread safety handling #####
594
595    @property
596    def allow_thread_sharing(self) -> bool:
597        with self._thread_sharing_lock:
598            return self._thread_sharing_count > 0
599
600    def validate_thread_sharing(self) -> None:
601        """
602        Validate that the connection isn't accessed by another thread than the
603        one which originally created it, unless the connection was explicitly
604        authorized to be shared between threads (via the `inc_thread_sharing()`
605        method). Raise an exception if the validation fails.
606        """
607        if not (self.allow_thread_sharing or self._thread_ident == _thread.get_ident()):
608            raise DatabaseError(
609                "DatabaseWrapper objects created in a "
610                "thread can only be used in that same thread. The connection "
611                f"was created in thread id {self._thread_ident} and this is "
612                f"thread id {_thread.get_ident()}."
613            )
614
615    # ##### Miscellaneous #####
616
617    def prepare_database(self) -> None:
618        """
619        Hook to do any database check or preparation, generally called before
620        migrating a project or an app.
621        """
622        pass
623
624    @cached_property
625    def wrap_database_errors(self) -> DatabaseErrorWrapper:
626        """
627        Context manager and decorator that re-throws backend-specific database
628        exceptions using Plain's common wrappers.
629        """
630        return DatabaseErrorWrapper(self)
631
632    def chunked_cursor(self) -> utils.CursorWrapper:
633        """
634        Return a cursor that tries to avoid caching in the database (if
635        supported by the database), otherwise return a regular cursor.
636        """
637        return self.cursor()
638
639    def make_debug_cursor(self, cursor: Any) -> utils.CursorDebugWrapper:
640        """Create a cursor that logs all queries in self.queries_log."""
641        return utils.CursorDebugWrapper(cursor, self)
642
643    def make_cursor(self, cursor: Any) -> utils.CursorWrapper:
644        """Create a cursor without debug logging."""
645        return utils.CursorWrapper(cursor, self)
646
647    @contextmanager
648    def temporary_connection(self) -> Generator[utils.CursorWrapper, None, None]:
649        """
650        Context manager that ensures that a connection is established, and
651        if it opened one, closes it to avoid leaving a dangling connection.
652        This is useful for operations outside of the request-response cycle.
653
654        Provide a cursor: with self.temporary_connection() as cursor: ...
655        """
656        must_close = self.connection is None
657        try:
658            with self.cursor() as cursor:
659                yield cursor
660        finally:
661            if must_close:
662                self.close()
663
664    @contextmanager
665    def _nodb_cursor(self) -> Generator[utils.CursorWrapper, None, None]:
666        """
667        Return a cursor from an alternative connection to be used when there is
668        no need to access the main database, specifically for test db
669        creation/deletion. This also prevents the production database from
670        being exposed to potential child threads while (or after) the test
671        database is destroyed. Refs #10868, #17786, #16969.
672        """
673        conn = self.__class__({**self.settings_dict, "NAME": None})
674        try:
675            with conn.cursor() as cursor:
676                yield cursor
677        finally:
678            conn.close()
679
680    def schema_editor(self, *args: Any, **kwargs: Any) -> BaseDatabaseSchemaEditor:
681        """
682        Return a new instance of this backend's SchemaEditor.
683        """
684        if self.SchemaEditorClass is None:
685            raise NotImplementedError(
686                "The SchemaEditorClass attribute of this database wrapper is still None"
687            )
688        return self.SchemaEditorClass(self, *args, **kwargs)
689
690    def on_commit(self, func: Any, robust: bool = False) -> None:
691        if not callable(func):
692            raise TypeError("on_commit()'s callback must be a callable.")
693        if self.in_atomic_block:
694            # Transaction in progress; save for execution on commit.
695            self.run_on_commit.append((set(self.savepoint_ids), func, robust))
696        elif not self.get_autocommit():
697            raise TransactionManagementError(
698                "on_commit() cannot be used in manual transaction management"
699            )
700        else:
701            # No transaction in progress and in autocommit mode; execute
702            # immediately.
703            if robust:
704                try:
705                    func()
706                except Exception as e:
707                    logger.error(
708                        f"Error calling {func.__qualname__} in on_commit() (%s).",
709                        e,
710                        exc_info=True,
711                    )
712            else:
713                func()
714
715    def run_and_clear_commit_hooks(self) -> None:
716        self.validate_no_atomic_block()
717        current_run_on_commit = self.run_on_commit
718        self.run_on_commit = []
719        while current_run_on_commit:
720            _, func, robust = current_run_on_commit.pop(0)
721            if robust:
722                try:
723                    func()
724                except Exception as e:
725                    logger.error(
726                        f"Error calling {func.__qualname__} in on_commit() during "
727                        f"transaction (%s).",
728                        e,
729                        exc_info=True,
730                    )
731            else:
732                func()
733
734    @contextmanager
735    def execute_wrapper(self, wrapper: Any) -> Generator[None, None, None]:
736        """
737        Return a context manager under which the wrapper is applied to suitable
738        database query executions.
739        """
740        self.execute_wrappers.append(wrapper)
741        try:
742            yield
743        finally:
744            self.execute_wrappers.pop()
745
746    def copy(self) -> BaseDatabaseWrapper:
747        """
748        Return a copy of this connection.
749
750        For tests that require two connections to the same database.
751        """
752        settings_dict = copy.deepcopy(self.settings_dict)
753        return type(self)(settings_dict)