Plain is headed towards 1.0! Subscribe for development updates →

  1import _thread
  2import copy
  3import datetime
  4import logging
  5import threading
  6import time
  7import warnings
  8import zoneinfo
  9from collections import deque
 10from contextlib import contextmanager
 11from functools import cached_property
 12
 13from plain.models.backends import utils
 14from plain.models.backends.base.validation import BaseDatabaseValidation
 15from plain.models.backends.utils import debug_transaction
 16from plain.models.db import (
 17    DatabaseError,
 18    DatabaseErrorWrapper,
 19    NotSupportedError,
 20)
 21from plain.models.transaction import TransactionManagementError
 22from plain.runtime import settings
 23
 24RAN_DB_VERSION_CHECK = False
 25
 26logger = logging.getLogger("plain.models.backends.base")
 27
 28
 29class BaseDatabaseWrapper:
 30    """Represent a database connection."""
 31
 32    # Mapping of Field objects to their column types.
 33    data_types = {}
 34    # Mapping of Field objects to their SQL suffix such as AUTOINCREMENT.
 35    data_types_suffix = {}
 36    # Mapping of Field objects to their SQL for CHECK constraints.
 37    data_type_check_constraints = {}
 38    ops = None
 39    vendor = "unknown"
 40    display_name = "unknown"
 41    SchemaEditorClass = None
 42    # Classes instantiated in __init__().
 43    client_class = None
 44    creation_class = None
 45    features_class = None
 46    introspection_class = None
 47    ops_class = None
 48    validation_class = BaseDatabaseValidation
 49
 50    queries_limit = 9000
 51
 52    def __init__(self, settings_dict):
 53        # Connection related attributes.
 54        # The underlying database connection.
 55        self.connection = None
 56        # `settings_dict` should be a dictionary containing keys such as
 57        # NAME, USER, etc. It's called `settings_dict` instead of `settings`
 58        # to disambiguate it from Plain settings modules.
 59        self.settings_dict = settings_dict
 60        # Query logging in debug mode or when explicitly enabled.
 61        self.queries_log = deque(maxlen=self.queries_limit)
 62        self.force_debug_cursor = False
 63
 64        # Transaction related attributes.
 65        # Tracks if the connection is in autocommit mode. Per PEP 249, by
 66        # default, it isn't.
 67        self.autocommit = False
 68        # Tracks if the connection is in a transaction managed by 'atomic'.
 69        self.in_atomic_block = False
 70        # Increment to generate unique savepoint ids.
 71        self.savepoint_state = 0
 72        # List of savepoints created by 'atomic'.
 73        self.savepoint_ids = []
 74        # Stack of active 'atomic' blocks.
 75        self.atomic_blocks = []
 76        # Tracks if the outermost 'atomic' block should commit on exit,
 77        # ie. if autocommit was active on entry.
 78        self.commit_on_exit = True
 79        # Tracks if the transaction should be rolled back to the next
 80        # available savepoint because of an exception in an inner block.
 81        self.needs_rollback = False
 82        self.rollback_exc = None
 83
 84        # Connection termination related attributes.
 85        self.close_at = None
 86        self.closed_in_transaction = False
 87        self.errors_occurred = False
 88        self.health_check_enabled = False
 89        self.health_check_done = False
 90
 91        # Thread-safety related attributes.
 92        self._thread_sharing_lock = threading.Lock()
 93        self._thread_sharing_count = 0
 94        self._thread_ident = _thread.get_ident()
 95
 96        # A list of no-argument functions to run when the transaction commits.
 97        # Each entry is an (sids, func, robust) tuple, where sids is a set of
 98        # the active savepoint IDs when this function was registered and robust
 99        # specifies whether it's allowed for the function to fail.
100        self.run_on_commit = []
101
102        # Should we run the on-commit hooks the next time set_autocommit(True)
103        # is called?
104        self.run_commit_hooks_on_set_autocommit_on = False
105
106        # A stack of wrappers to be invoked around execute()/executemany()
107        # calls. Each entry is a function taking five arguments: execute, sql,
108        # params, many, and context. It's the function's responsibility to
109        # call execute(sql, params, many, context).
110        self.execute_wrappers = []
111
112        self.client = self.client_class(self)
113        self.creation = self.creation_class(self)
114        self.features = self.features_class(self)
115        self.introspection = self.introspection_class(self)
116        self.ops = self.ops_class(self)
117        self.validation = self.validation_class(self)
118
119    def __repr__(self):
120        return f"<{self.__class__.__qualname__} vendor={self.vendor!r}>"
121
122    def ensure_timezone(self):
123        """
124        Ensure the connection's timezone is set to `self.timezone_name` and
125        return whether it changed or not.
126        """
127        return False
128
129    @cached_property
130    def timezone(self):
131        """
132        Return a tzinfo of the database connection time zone.
133
134        This is only used when time zone support is enabled. When a datetime is
135        read from the database, it is always returned in this time zone.
136
137        When the database backend supports time zones, it doesn't matter which
138        time zone Plain uses, as long as aware datetimes are used everywhere.
139        Other users connecting to the database can choose their own time zone.
140
141        When the database backend doesn't support time zones, the time zone
142        Plain uses may be constrained by the requirements of other users of
143        the database.
144        """
145        if self.settings_dict["TIME_ZONE"] is None:
146            return datetime.UTC
147        else:
148            return zoneinfo.ZoneInfo(self.settings_dict["TIME_ZONE"])
149
150    @cached_property
151    def timezone_name(self):
152        """
153        Name of the time zone of the database connection.
154        """
155        if self.settings_dict["TIME_ZONE"] is None:
156            return "UTC"
157        else:
158            return self.settings_dict["TIME_ZONE"]
159
160    @property
161    def queries_logged(self):
162        return self.force_debug_cursor or settings.DEBUG
163
164    @property
165    def queries(self):
166        if len(self.queries_log) == self.queries_log.maxlen:
167            warnings.warn(
168                f"Limit for query logging exceeded, only the last {self.queries_log.maxlen} queries "
169                "will be returned."
170            )
171        return list(self.queries_log)
172
173    def get_database_version(self):
174        """Return a tuple of the database's version."""
175        raise NotImplementedError(
176            "subclasses of BaseDatabaseWrapper may require a get_database_version() "
177            "method."
178        )
179
180    def check_database_version_supported(self):
181        """
182        Raise an error if the database version isn't supported by this
183        version of Plain.
184        """
185        if (
186            self.features.minimum_database_version is not None
187            and self.get_database_version() < self.features.minimum_database_version
188        ):
189            db_version = ".".join(map(str, self.get_database_version()))
190            min_db_version = ".".join(map(str, self.features.minimum_database_version))
191            raise NotSupportedError(
192                f"{self.display_name} {min_db_version} or later is required "
193                f"(found {db_version})."
194            )
195
196    # ##### Backend-specific methods for creating connections and cursors #####
197
198    def get_connection_params(self):
199        """Return a dict of parameters suitable for get_new_connection."""
200        raise NotImplementedError(
201            "subclasses of BaseDatabaseWrapper may require a get_connection_params() "
202            "method"
203        )
204
205    def get_new_connection(self, conn_params):
206        """Open a connection to the database."""
207        raise NotImplementedError(
208            "subclasses of BaseDatabaseWrapper may require a get_new_connection() "
209            "method"
210        )
211
212    def init_connection_state(self):
213        """Initialize the database connection settings."""
214        global RAN_DB_VERSION_CHECK
215        if not RAN_DB_VERSION_CHECK:
216            self.check_database_version_supported()
217            RAN_DB_VERSION_CHECK = True
218
219    def create_cursor(self, name=None):
220        """Create a cursor. Assume that a connection is established."""
221        raise NotImplementedError(
222            "subclasses of BaseDatabaseWrapper may require a create_cursor() method"
223        )
224
225    # ##### Backend-specific methods for creating connections #####
226
227    def connect(self):
228        """Connect to the database. Assume that the connection is closed."""
229        # In case the previous connection was closed while in an atomic block
230        self.in_atomic_block = False
231        self.savepoint_ids = []
232        self.atomic_blocks = []
233        self.needs_rollback = False
234        # Reset parameters defining when to close/health-check the connection.
235        self.health_check_enabled = self.settings_dict["CONN_HEALTH_CHECKS"]
236        max_age = self.settings_dict["CONN_MAX_AGE"]
237        self.close_at = None if max_age is None else time.monotonic() + max_age
238        self.closed_in_transaction = False
239        self.errors_occurred = False
240        # New connections are healthy.
241        self.health_check_done = True
242        # Establish the connection
243        conn_params = self.get_connection_params()
244        self.connection = self.get_new_connection(conn_params)
245        self.set_autocommit(self.settings_dict["AUTOCOMMIT"])
246        self.init_connection_state()
247
248        self.run_on_commit = []
249
250    def ensure_connection(self):
251        """Guarantee that a connection to the database is established."""
252        if self.connection is None:
253            with self.wrap_database_errors:
254                self.connect()
255
256    # ##### Backend-specific wrappers for PEP-249 connection methods #####
257
258    def _prepare_cursor(self, cursor):
259        """
260        Validate the connection is usable and perform database cursor wrapping.
261        """
262        self.validate_thread_sharing()
263        if self.queries_logged:
264            wrapped_cursor = self.make_debug_cursor(cursor)
265        else:
266            wrapped_cursor = self.make_cursor(cursor)
267        return wrapped_cursor
268
269    def _cursor(self, name=None):
270        self.close_if_health_check_failed()
271        self.ensure_connection()
272        with self.wrap_database_errors:
273            return self._prepare_cursor(self.create_cursor(name))
274
275    def _commit(self):
276        if self.connection is not None:
277            with debug_transaction(self, "COMMIT"), self.wrap_database_errors:
278                return self.connection.commit()
279
280    def _rollback(self):
281        if self.connection is not None:
282            with debug_transaction(self, "ROLLBACK"), self.wrap_database_errors:
283                return self.connection.rollback()
284
285    def _close(self):
286        if self.connection is not None:
287            with self.wrap_database_errors:
288                return self.connection.close()
289
290    # ##### Generic wrappers for PEP-249 connection methods #####
291
292    def cursor(self):
293        """Create a cursor, opening a connection if necessary."""
294        return self._cursor()
295
296    def commit(self):
297        """Commit a transaction and reset the dirty flag."""
298        self.validate_thread_sharing()
299        self.validate_no_atomic_block()
300        self._commit()
301        # A successful commit means that the database connection works.
302        self.errors_occurred = False
303        self.run_commit_hooks_on_set_autocommit_on = True
304
305    def rollback(self):
306        """Roll back a transaction and reset the dirty flag."""
307        self.validate_thread_sharing()
308        self.validate_no_atomic_block()
309        self._rollback()
310        # A successful rollback means that the database connection works.
311        self.errors_occurred = False
312        self.needs_rollback = False
313        self.run_on_commit = []
314
315    def close(self):
316        """Close the connection to the database."""
317        self.validate_thread_sharing()
318        self.run_on_commit = []
319
320        # Don't call validate_no_atomic_block() to avoid making it difficult
321        # to get rid of a connection in an invalid state. The next connect()
322        # will reset the transaction state anyway.
323        if self.closed_in_transaction or self.connection is None:
324            return
325        try:
326            self._close()
327        finally:
328            if self.in_atomic_block:
329                self.closed_in_transaction = True
330                self.needs_rollback = True
331            else:
332                self.connection = None
333
334    # ##### Backend-specific savepoint management methods #####
335
336    def _savepoint(self, sid):
337        with self.cursor() as cursor:
338            cursor.execute(self.ops.savepoint_create_sql(sid))
339
340    def _savepoint_rollback(self, sid):
341        with self.cursor() as cursor:
342            cursor.execute(self.ops.savepoint_rollback_sql(sid))
343
344    def _savepoint_commit(self, sid):
345        with self.cursor() as cursor:
346            cursor.execute(self.ops.savepoint_commit_sql(sid))
347
348    def _savepoint_allowed(self):
349        # Savepoints cannot be created outside a transaction
350        return self.features.uses_savepoints and not self.get_autocommit()
351
352    # ##### Generic savepoint management methods #####
353
354    def savepoint(self):
355        """
356        Create a savepoint inside the current transaction. Return an
357        identifier for the savepoint that will be used for the subsequent
358        rollback or commit. Do nothing if savepoints are not supported.
359        """
360        if not self._savepoint_allowed():
361            return
362
363        thread_ident = _thread.get_ident()
364        tid = str(thread_ident).replace("-", "")
365
366        self.savepoint_state += 1
367        sid = "s%s_x%d" % (tid, self.savepoint_state)  # noqa: UP031
368
369        self.validate_thread_sharing()
370        self._savepoint(sid)
371
372        return sid
373
374    def savepoint_rollback(self, sid):
375        """
376        Roll back to a savepoint. Do nothing if savepoints are not supported.
377        """
378        if not self._savepoint_allowed():
379            return
380
381        self.validate_thread_sharing()
382        self._savepoint_rollback(sid)
383
384        # Remove any callbacks registered while this savepoint was active.
385        self.run_on_commit = [
386            (sids, func, robust)
387            for (sids, func, robust) in self.run_on_commit
388            if sid not in sids
389        ]
390
391    def savepoint_commit(self, sid):
392        """
393        Release a savepoint. Do nothing if savepoints are not supported.
394        """
395        if not self._savepoint_allowed():
396            return
397
398        self.validate_thread_sharing()
399        self._savepoint_commit(sid)
400
401    def clean_savepoints(self):
402        """
403        Reset the counter used to generate unique savepoint ids in this thread.
404        """
405        self.savepoint_state = 0
406
407    # ##### Backend-specific transaction management methods #####
408
409    def _set_autocommit(self, autocommit):
410        """
411        Backend-specific implementation to enable or disable autocommit.
412        """
413        raise NotImplementedError(
414            "subclasses of BaseDatabaseWrapper may require a _set_autocommit() method"
415        )
416
417    # ##### Generic transaction management methods #####
418
419    def get_autocommit(self):
420        """Get the autocommit state."""
421        self.ensure_connection()
422        return self.autocommit
423
424    def set_autocommit(
425        self, autocommit, force_begin_transaction_with_broken_autocommit=False
426    ):
427        """
428        Enable or disable autocommit.
429
430        The usual way to start a transaction is to turn autocommit off.
431        SQLite does not properly start a transaction when disabling
432        autocommit. To avoid this buggy behavior and to actually enter a new
433        transaction, an explicit BEGIN is required. Using
434        force_begin_transaction_with_broken_autocommit=True will issue an
435        explicit BEGIN with SQLite. This option will be ignored for other
436        backends.
437        """
438        self.validate_no_atomic_block()
439        self.close_if_health_check_failed()
440        self.ensure_connection()
441
442        start_transaction_under_autocommit = (
443            force_begin_transaction_with_broken_autocommit
444            and not autocommit
445            and hasattr(self, "_start_transaction_under_autocommit")
446        )
447
448        if start_transaction_under_autocommit:
449            self._start_transaction_under_autocommit()
450        elif autocommit:
451            self._set_autocommit(autocommit)
452        else:
453            with debug_transaction(self, "BEGIN"):
454                self._set_autocommit(autocommit)
455        self.autocommit = autocommit
456
457        if autocommit and self.run_commit_hooks_on_set_autocommit_on:
458            self.run_and_clear_commit_hooks()
459            self.run_commit_hooks_on_set_autocommit_on = False
460
461    def get_rollback(self):
462        """Get the "needs rollback" flag -- for *advanced use* only."""
463        if not self.in_atomic_block:
464            raise TransactionManagementError(
465                "The rollback flag doesn't work outside of an 'atomic' block."
466            )
467        return self.needs_rollback
468
469    def set_rollback(self, rollback):
470        """
471        Set or unset the "needs rollback" flag -- for *advanced use* only.
472        """
473        if not self.in_atomic_block:
474            raise TransactionManagementError(
475                "The rollback flag doesn't work outside of an 'atomic' block."
476            )
477        self.needs_rollback = rollback
478
479    def validate_no_atomic_block(self):
480        """Raise an error if an atomic block is active."""
481        if self.in_atomic_block:
482            raise TransactionManagementError(
483                "This is forbidden when an 'atomic' block is active."
484            )
485
486    def validate_no_broken_transaction(self):
487        if self.needs_rollback:
488            raise TransactionManagementError(
489                "An error occurred in the current transaction. You can't "
490                "execute queries until the end of the 'atomic' block."
491            ) from self.rollback_exc
492
493    # ##### Foreign key constraints checks handling #####
494
495    def disable_constraint_checking(self):
496        """
497        Backends can implement as needed to temporarily disable foreign key
498        constraint checking. Should return True if the constraints were
499        disabled and will need to be reenabled.
500        """
501        return False
502
503    def enable_constraint_checking(self):
504        """
505        Backends can implement as needed to re-enable foreign key constraint
506        checking.
507        """
508        pass
509
510    def check_constraints(self, table_names=None):
511        """
512        Backends can override this method if they can apply constraint
513        checking (e.g. via "SET CONSTRAINTS ALL IMMEDIATE"). Should raise an
514        IntegrityError if any invalid foreign key references are encountered.
515        """
516        pass
517
518    # ##### Connection termination handling #####
519
520    def is_usable(self):
521        """
522        Test if the database connection is usable.
523
524        This method may assume that self.connection is not None.
525
526        Actual implementations should take care not to raise exceptions
527        as that may prevent Plain from recycling unusable connections.
528        """
529        raise NotImplementedError(
530            "subclasses of BaseDatabaseWrapper may require an is_usable() method"
531        )
532
533    def close_if_health_check_failed(self):
534        """Close existing connection if it fails a health check."""
535        if (
536            self.connection is None
537            or not self.health_check_enabled
538            or self.health_check_done
539        ):
540            return
541
542        if not self.is_usable():
543            self.close()
544        self.health_check_done = True
545
546    def close_if_unusable_or_obsolete(self):
547        """
548        Close the current connection if unrecoverable errors have occurred
549        or if it outlived its maximum age.
550        """
551        if self.connection is not None:
552            self.health_check_done = False
553            # If the application didn't restore the original autocommit setting,
554            # don't take chances, drop the connection.
555            if self.get_autocommit() != self.settings_dict["AUTOCOMMIT"]:
556                self.close()
557                return
558
559            # If an exception other than DataError or IntegrityError occurred
560            # since the last commit / rollback, check if the connection works.
561            if self.errors_occurred:
562                if self.is_usable():
563                    self.errors_occurred = False
564                    self.health_check_done = True
565                else:
566                    self.close()
567                    return
568
569            if self.close_at is not None and time.monotonic() >= self.close_at:
570                self.close()
571                return
572
573    # ##### Thread safety handling #####
574
575    @property
576    def allow_thread_sharing(self):
577        with self._thread_sharing_lock:
578            return self._thread_sharing_count > 0
579
580    def validate_thread_sharing(self):
581        """
582        Validate that the connection isn't accessed by another thread than the
583        one which originally created it, unless the connection was explicitly
584        authorized to be shared between threads (via the `inc_thread_sharing()`
585        method). Raise an exception if the validation fails.
586        """
587        if not (self.allow_thread_sharing or self._thread_ident == _thread.get_ident()):
588            raise DatabaseError(
589                "DatabaseWrapper objects created in a "
590                "thread can only be used in that same thread. The connection "
591                f"was created in thread id {self._thread_ident} and this is "
592                f"thread id {_thread.get_ident()}."
593            )
594
595    # ##### Miscellaneous #####
596
597    def prepare_database(self):
598        """
599        Hook to do any database check or preparation, generally called before
600        migrating a project or an app.
601        """
602        pass
603
604    @cached_property
605    def wrap_database_errors(self):
606        """
607        Context manager and decorator that re-throws backend-specific database
608        exceptions using Plain's common wrappers.
609        """
610        return DatabaseErrorWrapper(self)
611
612    def chunked_cursor(self):
613        """
614        Return a cursor that tries to avoid caching in the database (if
615        supported by the database), otherwise return a regular cursor.
616        """
617        return self.cursor()
618
619    def make_debug_cursor(self, cursor):
620        """Create a cursor that logs all queries in self.queries_log."""
621        return utils.CursorDebugWrapper(cursor, self)
622
623    def make_cursor(self, cursor):
624        """Create a cursor without debug logging."""
625        return utils.CursorWrapper(cursor, self)
626
627    @contextmanager
628    def temporary_connection(self):
629        """
630        Context manager that ensures that a connection is established, and
631        if it opened one, closes it to avoid leaving a dangling connection.
632        This is useful for operations outside of the request-response cycle.
633
634        Provide a cursor: with self.temporary_connection() as cursor: ...
635        """
636        must_close = self.connection is None
637        try:
638            with self.cursor() as cursor:
639                yield cursor
640        finally:
641            if must_close:
642                self.close()
643
644    @contextmanager
645    def _nodb_cursor(self):
646        """
647        Return a cursor from an alternative connection to be used when there is
648        no need to access the main database, specifically for test db
649        creation/deletion. This also prevents the production database from
650        being exposed to potential child threads while (or after) the test
651        database is destroyed. Refs #10868, #17786, #16969.
652        """
653        conn = self.__class__({**self.settings_dict, "NAME": None})
654        try:
655            with conn.cursor() as cursor:
656                yield cursor
657        finally:
658            conn.close()
659
660    def schema_editor(self, *args, **kwargs):
661        """
662        Return a new instance of this backend's SchemaEditor.
663        """
664        if self.SchemaEditorClass is None:
665            raise NotImplementedError(
666                "The SchemaEditorClass attribute of this database wrapper is still None"
667            )
668        return self.SchemaEditorClass(self, *args, **kwargs)
669
670    def on_commit(self, func, robust=False):
671        if not callable(func):
672            raise TypeError("on_commit()'s callback must be a callable.")
673        if self.in_atomic_block:
674            # Transaction in progress; save for execution on commit.
675            self.run_on_commit.append((set(self.savepoint_ids), func, robust))
676        elif not self.get_autocommit():
677            raise TransactionManagementError(
678                "on_commit() cannot be used in manual transaction management"
679            )
680        else:
681            # No transaction in progress and in autocommit mode; execute
682            # immediately.
683            if robust:
684                try:
685                    func()
686                except Exception as e:
687                    logger.error(
688                        f"Error calling {func.__qualname__} in on_commit() (%s).",
689                        e,
690                        exc_info=True,
691                    )
692            else:
693                func()
694
695    def run_and_clear_commit_hooks(self):
696        self.validate_no_atomic_block()
697        current_run_on_commit = self.run_on_commit
698        self.run_on_commit = []
699        while current_run_on_commit:
700            _, func, robust = current_run_on_commit.pop(0)
701            if robust:
702                try:
703                    func()
704                except Exception as e:
705                    logger.error(
706                        f"Error calling {func.__qualname__} in on_commit() during "
707                        f"transaction (%s).",
708                        e,
709                        exc_info=True,
710                    )
711            else:
712                func()
713
714    @contextmanager
715    def execute_wrapper(self, wrapper):
716        """
717        Return a context manager under which the wrapper is applied to suitable
718        database query executions.
719        """
720        self.execute_wrappers.append(wrapper)
721        try:
722            yield
723        finally:
724            self.execute_wrappers.pop()
725
726    def copy(self):
727        """
728        Return a copy of this connection.
729
730        For tests that require two connections to the same database.
731        """
732        settings_dict = copy.deepcopy(self.settings_dict)
733        return type(self)(settings_dict)