1import _thread
2import copy
3import datetime
4import logging
5import threading
6import time
7import warnings
8import zoneinfo
9from collections import deque
10from contextlib import contextmanager
11
12from plain.exceptions import ImproperlyConfigured
13from plain.models.backends import utils
14from plain.models.backends.base.validation import BaseDatabaseValidation
15from plain.models.backends.signals import connection_created
16from plain.models.backends.utils import debug_transaction
17from plain.models.db import (
18 DEFAULT_DB_ALIAS,
19 DatabaseError,
20 DatabaseErrorWrapper,
21 NotSupportedError,
22)
23from plain.models.transaction import TransactionManagementError
24from plain.runtime import settings
25from plain.utils.functional import cached_property
26
# Alias given to the throwaway connection opened without a database NAME
# (see BaseDatabaseWrapper._nodb_cursor()).
NO_DB_ALIAS: str = "__no_db__"

# Connection aliases whose server version has already passed
# BaseDatabaseWrapper.check_database_version_supported(); an alias is added
# by init_connection_state() only after the check succeeds, so the check
# runs at most once per alias for the lifetime of this module.
RAN_DB_VERSION_CHECK: set = set()

# Module logger; used to report failures of robust on-commit callbacks.
logger = logging.getLogger(name="plain.models.backends.base")
31
32
class BaseDatabaseWrapper:
    """
    Represent a database connection.

    Concrete backends subclass this, set the ``*_class`` attributes, the
    ``data_types*`` mappings, ``vendor`` and ``display_name``, and implement
    the methods that raise NotImplementedError here: get_database_version(),
    get_connection_params(), get_new_connection(), create_cursor(),
    _set_autocommit() and is_usable().
    """

    # Mapping of Field objects to their column types.
    data_types = {}
    # Mapping of Field objects to their SQL suffix such as AUTOINCREMENT.
    data_types_suffix = {}
    # Mapping of Field objects to their SQL for CHECK constraints.
    data_type_check_constraints = {}
    ops = None
    vendor = "unknown"
    display_name = "unknown"
    SchemaEditorClass = None
    # Classes instantiated in __init__().
    client_class = None
    creation_class = None
    features_class = None
    introspection_class = None
    ops_class = None
    validation_class = BaseDatabaseValidation

    # Maximum number of entries kept in ``queries_log`` (a bounded deque, so
    # the oldest entries are dropped once it is full).
    queries_limit = 9000

    def __init__(self, settings_dict, alias=DEFAULT_DB_ALIAS):
        """
        Set up connection, transaction and thread-tracking state.

        No database connection is opened here; that happens lazily in
        ensure_connection()/connect().

        Args:
            settings_dict: per-connection configuration (NAME, USER,
                TIME_ZONE, AUTOCOMMIT, CONN_MAX_AGE, CONN_HEALTH_CHECKS, ...).
            alias: name under which this connection is known.
        """
        # Connection related attributes.
        # The underlying database connection.
        self.connection = None
        # `settings_dict` should be a dictionary containing keys such as
        # NAME, USER, etc. It's called `settings_dict` instead of `settings`
        # to disambiguate it from Plain settings modules.
        self.settings_dict = settings_dict
        self.alias = alias
        # Query logging in debug mode or when explicitly enabled.
        self.queries_log = deque(maxlen=self.queries_limit)
        self.force_debug_cursor = False

        # Transaction related attributes.
        # Tracks if the connection is in autocommit mode. Per PEP 249, by
        # default, it isn't.
        self.autocommit = False
        # Tracks if the connection is in a transaction managed by 'atomic'.
        self.in_atomic_block = False
        # Increment to generate unique savepoint ids.
        self.savepoint_state = 0
        # List of savepoints created by 'atomic'.
        self.savepoint_ids = []
        # Stack of active 'atomic' blocks.
        self.atomic_blocks = []
        # Tracks if the outermost 'atomic' block should commit on exit,
        # ie. if autocommit was active on entry.
        self.commit_on_exit = True
        # Tracks if the transaction should be rolled back to the next
        # available savepoint because of an exception in an inner block.
        self.needs_rollback = False
        # Exception chained as the cause of the error raised by
        # validate_no_broken_transaction().
        self.rollback_exc = None

        # Connection termination related attributes.
        self.close_at = None
        self.closed_in_transaction = False
        self.errors_occurred = False
        self.health_check_enabled = False
        self.health_check_done = False

        # Thread-safety related attributes.
        self._thread_sharing_lock = threading.Lock()
        self._thread_sharing_count = 0
        self._thread_ident = _thread.get_ident()

        # A list of no-argument functions to run when the transaction commits.
        # Each entry is an (sids, func, robust) tuple, where sids is a set of
        # the active savepoint IDs when this function was registered and robust
        # specifies whether it's allowed for the function to fail.
        self.run_on_commit = []

        # Should we run the on-commit hooks the next time set_autocommit(True)
        # is called?
        self.run_commit_hooks_on_set_autocommit_on = False

        # A stack of wrappers to be invoked around execute()/executemany()
        # calls. Each entry is a function taking five arguments: execute, sql,
        # params, many, and context. It's the function's responsibility to
        # call execute(sql, params, many, context).
        self.execute_wrappers = []

        self.client = self.client_class(self)
        self.creation = self.creation_class(self)
        self.features = self.features_class(self)
        self.introspection = self.introspection_class(self)
        self.ops = self.ops_class(self)
        self.validation = self.validation_class(self)

    def __repr__(self):
        return (
            f"<{self.__class__.__qualname__} "
            f"vendor={self.vendor!r} alias={self.alias!r}>"
        )

    def ensure_timezone(self):
        """
        Ensure the connection's timezone is set to `self.timezone_name` and
        return whether it changed or not.

        The base implementation does nothing and returns False.
        """
        return False

    @cached_property
    def timezone(self):
        """
        Return a tzinfo of the database connection time zone.

        This is only used when time zone support is enabled. When a datetime is
        read from the database, it is always returned in this time zone.

        When the database backend supports time zones, it doesn't matter which
        time zone Plain uses, as long as aware datetimes are used everywhere.
        Other users connecting to the database can choose their own time zone.

        When the database backend doesn't support time zones, the time zone
        Plain uses may be constrained by the requirements of other users of
        the database.

        Returns None when USE_TZ is False, UTC when the connection has no
        TIME_ZONE setting, and a ZoneInfo for that setting otherwise.
        """
        if not settings.USE_TZ:
            return None
        elif self.settings_dict["TIME_ZONE"] is None:
            return datetime.timezone.utc
        else:
            return zoneinfo.ZoneInfo(self.settings_dict["TIME_ZONE"])

    @cached_property
    def timezone_name(self):
        """
        Name of the time zone of the database connection.

        settings.TIME_ZONE when USE_TZ is False; otherwise the connection's
        TIME_ZONE setting, or "UTC" when that is None.
        """
        if not settings.USE_TZ:
            return settings.TIME_ZONE
        elif self.settings_dict["TIME_ZONE"] is None:
            return "UTC"
        else:
            return self.settings_dict["TIME_ZONE"]

    @property
    def queries_logged(self):
        """
        Whether executed queries should be recorded in ``queries_log``:
        truthy when force_debug_cursor is set or settings.DEBUG is on.
        """
        return self.force_debug_cursor or settings.DEBUG

    @property
    def queries(self):
        """
        Return a list copy of the logged queries.

        Emits a warning when the log is full, since older queries may have
        been discarded by the bounded deque.
        """
        if len(self.queries_log) == self.queries_log.maxlen:
            warnings.warn(
                "Limit for query logging exceeded, only the last {} queries "
                "will be returned.".format(self.queries_log.maxlen)
            )
        return list(self.queries_log)

    def get_database_version(self):
        """Return a tuple of the database's version."""
        raise NotImplementedError(
            "subclasses of BaseDatabaseWrapper may require a get_database_version() "
            "method."
        )

    def check_database_version_supported(self):
        """
        Raise an error if the database version isn't supported by this
        version of Plain.

        Raises:
            NotSupportedError: if features.minimum_database_version is set and
                the version reported by get_database_version() is lower.
        """
        minimum_version = self.features.minimum_database_version
        if minimum_version is None:
            return
        # Fetch the version once and reuse it for both the comparison and
        # the error message.
        database_version = self.get_database_version()
        if database_version < minimum_version:
            db_version = ".".join(map(str, database_version))
            min_db_version = ".".join(map(str, minimum_version))
            raise NotSupportedError(
                f"{self.display_name} {min_db_version} or later is required "
                f"(found {db_version})."
            )

    # ##### Backend-specific methods for creating connections and cursors #####

    def get_connection_params(self):
        """Return a dict of parameters suitable for get_new_connection."""
        raise NotImplementedError(
            "subclasses of BaseDatabaseWrapper may require a get_connection_params() "
            "method"
        )

    def get_new_connection(self, conn_params):
        """Open a connection to the database."""
        raise NotImplementedError(
            "subclasses of BaseDatabaseWrapper may require a get_new_connection() "
            "method"
        )

    def init_connection_state(self):
        """
        Initialize the database connection settings.

        The server version check runs only for the first connection with this
        alias; the alias is recorded in the module-level RAN_DB_VERSION_CHECK
        set only after the check passes.
        """
        # The set is mutated in place, never rebound, so no `global` is needed.
        if self.alias not in RAN_DB_VERSION_CHECK:
            self.check_database_version_supported()
            RAN_DB_VERSION_CHECK.add(self.alias)

    def create_cursor(self, name=None):
        """Create a cursor. Assume that a connection is established."""
        raise NotImplementedError(
            "subclasses of BaseDatabaseWrapper may require a create_cursor() method"
        )

    # ##### Backend-specific methods for creating connections #####

    def connect(self):
        """Connect to the database. Assume that the connection is closed."""
        # Check for invalid configurations.
        self.check_settings()
        # In case the previous connection was closed while in an atomic block
        self.in_atomic_block = False
        self.savepoint_ids = []
        self.atomic_blocks = []
        self.needs_rollback = False
        # Reset parameters defining when to close/health-check the connection.
        self.health_check_enabled = self.settings_dict["CONN_HEALTH_CHECKS"]
        max_age = self.settings_dict["CONN_MAX_AGE"]
        self.close_at = None if max_age is None else time.monotonic() + max_age
        self.closed_in_transaction = False
        self.errors_occurred = False
        # New connections are healthy.
        self.health_check_done = True
        # Establish the connection
        conn_params = self.get_connection_params()
        self.connection = self.get_new_connection(conn_params)
        self.set_autocommit(self.settings_dict["AUTOCOMMIT"])
        self.init_connection_state()
        connection_created.send(sender=self.__class__, connection=self)

        # Hooks queued for a previous connection must not run on this one.
        self.run_on_commit = []

    def check_settings(self):
        """
        Raise ImproperlyConfigured if this connection sets TIME_ZONE while
        USE_TZ is disabled.
        """
        if self.settings_dict["TIME_ZONE"] is not None and not settings.USE_TZ:
            raise ImproperlyConfigured(
                "Connection '%s' cannot set TIME_ZONE because USE_TZ is False."
                % self.alias
            )

    def ensure_connection(self):
        """Guarantee that a connection to the database is established."""
        if self.connection is None:
            with self.wrap_database_errors:
                self.connect()

    # ##### Backend-specific wrappers for PEP-249 connection methods #####

    def _prepare_cursor(self, cursor):
        """
        Validate the connection is usable and perform database cursor wrapping.

        Returns a debug (query-logging) wrapper when queries_logged is truthy,
        otherwise a plain cursor wrapper.
        """
        self.validate_thread_sharing()
        if self.queries_logged:
            wrapped_cursor = self.make_debug_cursor(cursor)
        else:
            wrapped_cursor = self.make_cursor(cursor)
        return wrapped_cursor

    def _cursor(self, name=None):
        """
        Return a wrapped cursor, first closing the connection if its health
        check fails and (re)connecting if needed.
        """
        self.close_if_health_check_failed()
        self.ensure_connection()
        with self.wrap_database_errors:
            return self._prepare_cursor(self.create_cursor(name))

    def _commit(self):
        """Commit on the raw connection, if one is open."""
        if self.connection is not None:
            with debug_transaction(self, "COMMIT"), self.wrap_database_errors:
                return self.connection.commit()

    def _rollback(self):
        """Roll back on the raw connection, if one is open."""
        if self.connection is not None:
            with debug_transaction(self, "ROLLBACK"), self.wrap_database_errors:
                return self.connection.rollback()

    def _close(self):
        """Close the raw connection, if one is open."""
        if self.connection is not None:
            with self.wrap_database_errors:
                return self.connection.close()

    # ##### Generic wrappers for PEP-249 connection methods #####

    def cursor(self):
        """Create a cursor, opening a connection if necessary."""
        return self._cursor()

    def commit(self):
        """
        Commit a transaction and reset the dirty flag.

        Also arranges for queued on-commit hooks to run the next time
        set_autocommit(True) is called.
        """
        self.validate_thread_sharing()
        self.validate_no_atomic_block()
        self._commit()
        # A successful commit means that the database connection works.
        self.errors_occurred = False
        self.run_commit_hooks_on_set_autocommit_on = True

    def rollback(self):
        """
        Roll back a transaction and reset the dirty flag.

        Queued on-commit hooks are discarded.
        """
        self.validate_thread_sharing()
        self.validate_no_atomic_block()
        self._rollback()
        # A successful rollback means that the database connection works.
        self.errors_occurred = False
        self.needs_rollback = False
        self.run_on_commit = []

    def close(self):
        """
        Close the connection to the database.

        Inside an 'atomic' block the wrapper keeps its connection reference
        and instead marks the transaction as closed and needing rollback.
        """
        self.validate_thread_sharing()
        self.run_on_commit = []

        # Don't call validate_no_atomic_block() to avoid making it difficult
        # to get rid of a connection in an invalid state. The next connect()
        # will reset the transaction state anyway.
        if self.closed_in_transaction or self.connection is None:
            return
        try:
            self._close()
        finally:
            if self.in_atomic_block:
                self.closed_in_transaction = True
                self.needs_rollback = True
            else:
                self.connection = None

    # ##### Backend-specific savepoint management methods #####

    def _savepoint(self, sid):
        """Execute the backend's SQL to create savepoint ``sid``."""
        with self.cursor() as cursor:
            cursor.execute(self.ops.savepoint_create_sql(sid))

    def _savepoint_rollback(self, sid):
        """Execute the backend's SQL to roll back to savepoint ``sid``."""
        with self.cursor() as cursor:
            cursor.execute(self.ops.savepoint_rollback_sql(sid))

    def _savepoint_commit(self, sid):
        """Execute the backend's SQL to release savepoint ``sid``."""
        with self.cursor() as cursor:
            cursor.execute(self.ops.savepoint_commit_sql(sid))

    def _savepoint_allowed(self):
        """Return whether savepoints can be used right now."""
        # Savepoints cannot be created outside a transaction
        return self.features.uses_savepoints and not self.get_autocommit()

    # ##### Generic savepoint management methods #####

    def savepoint(self):
        """
        Create a savepoint inside the current transaction. Return an
        identifier for the savepoint that will be used for the subsequent
        rollback or commit. Do nothing if savepoints are not supported.
        """
        if not self._savepoint_allowed():
            return

        thread_ident = _thread.get_ident()
        # Drop any minus sign so the id stays a valid SQL identifier.
        tid = str(thread_ident).replace("-", "")

        self.savepoint_state += 1
        sid = "s%s_x%d" % (tid, self.savepoint_state)

        self.validate_thread_sharing()
        self._savepoint(sid)

        return sid

    def savepoint_rollback(self, sid):
        """
        Roll back to a savepoint. Do nothing if savepoints are not supported.
        """
        if not self._savepoint_allowed():
            return

        self.validate_thread_sharing()
        self._savepoint_rollback(sid)

        # Remove any callbacks registered while this savepoint was active.
        self.run_on_commit = [
            (sids, func, robust)
            for (sids, func, robust) in self.run_on_commit
            if sid not in sids
        ]

    def savepoint_commit(self, sid):
        """
        Release a savepoint. Do nothing if savepoints are not supported.
        """
        if not self._savepoint_allowed():
            return

        self.validate_thread_sharing()
        self._savepoint_commit(sid)

    def clean_savepoints(self):
        """
        Reset the counter used to generate unique savepoint ids in this thread.
        """
        self.savepoint_state = 0

    # ##### Backend-specific transaction management methods #####

    def _set_autocommit(self, autocommit):
        """
        Backend-specific implementation to enable or disable autocommit.
        """
        raise NotImplementedError(
            "subclasses of BaseDatabaseWrapper may require a _set_autocommit() method"
        )

    # ##### Generic transaction management methods #####

    def get_autocommit(self):
        """Get the autocommit state, connecting first if necessary."""
        self.ensure_connection()
        return self.autocommit

    def set_autocommit(
        self, autocommit, force_begin_transaction_with_broken_autocommit=False
    ):
        """
        Enable or disable autocommit.

        The usual way to start a transaction is to turn autocommit off.
        SQLite does not properly start a transaction when disabling
        autocommit. To avoid this buggy behavior and to actually enter a new
        transaction, an explicit BEGIN is required. Using
        force_begin_transaction_with_broken_autocommit=True will issue an
        explicit BEGIN with SQLite. This option will be ignored for other
        backends.

        Turning autocommit on runs pending on-commit hooks if a commit() has
        requested it.
        """
        self.validate_no_atomic_block()
        self.close_if_health_check_failed()
        self.ensure_connection()

        start_transaction_under_autocommit = (
            force_begin_transaction_with_broken_autocommit
            and not autocommit
            and hasattr(self, "_start_transaction_under_autocommit")
        )

        if start_transaction_under_autocommit:
            self._start_transaction_under_autocommit()
        elif autocommit:
            self._set_autocommit(autocommit)
        else:
            with debug_transaction(self, "BEGIN"):
                self._set_autocommit(autocommit)
        self.autocommit = autocommit

        if autocommit and self.run_commit_hooks_on_set_autocommit_on:
            self.run_and_clear_commit_hooks()
            self.run_commit_hooks_on_set_autocommit_on = False

    def get_rollback(self):
        """Get the "needs rollback" flag -- for *advanced use* only."""
        if not self.in_atomic_block:
            raise TransactionManagementError(
                "The rollback flag doesn't work outside of an 'atomic' block."
            )
        return self.needs_rollback

    def set_rollback(self, rollback):
        """
        Set or unset the "needs rollback" flag -- for *advanced use* only.
        """
        if not self.in_atomic_block:
            raise TransactionManagementError(
                "The rollback flag doesn't work outside of an 'atomic' block."
            )
        self.needs_rollback = rollback

    def validate_no_atomic_block(self):
        """Raise an error if an atomic block is active."""
        if self.in_atomic_block:
            raise TransactionManagementError(
                "This is forbidden when an 'atomic' block is active."
            )

    def validate_no_broken_transaction(self):
        """
        Raise TransactionManagementError (chained from rollback_exc) if the
        current transaction is flagged as needing a rollback.
        """
        if self.needs_rollback:
            raise TransactionManagementError(
                "An error occurred in the current transaction. You can't "
                "execute queries until the end of the 'atomic' block."
            ) from self.rollback_exc

    # ##### Foreign key constraints checks handling #####

    @contextmanager
    def constraint_checks_disabled(self):
        """
        Disable foreign key constraint checking.

        Checking is re-enabled on exit only if disable_constraint_checking()
        reported that it actually disabled it.
        """
        disabled = self.disable_constraint_checking()
        try:
            yield
        finally:
            if disabled:
                self.enable_constraint_checking()

    def disable_constraint_checking(self):
        """
        Backends can implement as needed to temporarily disable foreign key
        constraint checking. Should return True if the constraints were
        disabled and will need to be reenabled.
        """
        return False

    def enable_constraint_checking(self):
        """
        Backends can implement as needed to re-enable foreign key constraint
        checking.
        """
        pass

    def check_constraints(self, table_names=None):
        """
        Backends can override this method if they can apply constraint
        checking (e.g. via "SET CONSTRAINTS ALL IMMEDIATE"). Should raise an
        IntegrityError if any invalid foreign key references are encountered.
        """
        pass

    # ##### Connection termination handling #####

    def is_usable(self):
        """
        Test if the database connection is usable.

        This method may assume that self.connection is not None.

        Actual implementations should take care not to raise exceptions
        as that may prevent Plain from recycling unusable connections.
        """
        raise NotImplementedError(
            "subclasses of BaseDatabaseWrapper may require an is_usable() method"
        )

    def close_if_health_check_failed(self):
        """Close existing connection if it fails a health check."""
        if (
            self.connection is None
            or not self.health_check_enabled
            or self.health_check_done
        ):
            return

        if not self.is_usable():
            self.close()
        self.health_check_done = True

    def close_if_unusable_or_obsolete(self):
        """
        Close the current connection if unrecoverable errors have occurred
        or if it outlived its maximum age.
        """
        if self.connection is not None:
            self.health_check_done = False
            # If the application didn't restore the original autocommit setting,
            # don't take chances, drop the connection.
            if self.get_autocommit() != self.settings_dict["AUTOCOMMIT"]:
                self.close()
                return

            # If an exception other than DataError or IntegrityError occurred
            # since the last commit / rollback, check if the connection works.
            if self.errors_occurred:
                if self.is_usable():
                    self.errors_occurred = False
                    self.health_check_done = True
                else:
                    self.close()
                    return

            if self.close_at is not None and time.monotonic() >= self.close_at:
                self.close()
                return

    # ##### Thread safety handling #####

    @property
    def allow_thread_sharing(self):
        """Whether inc_thread_sharing() calls currently outnumber decrements."""
        with self._thread_sharing_lock:
            return self._thread_sharing_count > 0

    def inc_thread_sharing(self):
        """Allow (one more level of) use of this wrapper from other threads."""
        with self._thread_sharing_lock:
            self._thread_sharing_count += 1

    def dec_thread_sharing(self):
        """
        Undo one inc_thread_sharing() call.

        Raises:
            RuntimeError: if the count is already zero.
        """
        with self._thread_sharing_lock:
            if self._thread_sharing_count <= 0:
                raise RuntimeError(
                    "Cannot decrement the thread sharing count below zero."
                )
            self._thread_sharing_count -= 1

    def validate_thread_sharing(self):
        """
        Validate that the connection isn't accessed by another thread than the
        one which originally created it, unless the connection was explicitly
        authorized to be shared between threads (via the `inc_thread_sharing()`
        method). Raise an exception if the validation fails.
        """
        if not (self.allow_thread_sharing or self._thread_ident == _thread.get_ident()):
            raise DatabaseError(
                "DatabaseWrapper objects created in a "
                "thread can only be used in that same thread. The object "
                "with alias '{}' was created in thread id {} and this is "
                "thread id {}.".format(
                    self.alias, self._thread_ident, _thread.get_ident()
                )
            )

    # ##### Miscellaneous #####

    def prepare_database(self):
        """
        Hook to do any database check or preparation, generally called before
        migrating a project or an app.
        """
        pass

    @cached_property
    def wrap_database_errors(self):
        """
        Context manager and decorator that re-throws backend-specific database
        exceptions using Plain's common wrappers.
        """
        return DatabaseErrorWrapper(self)

    def chunked_cursor(self):
        """
        Return a cursor that tries to avoid caching in the database (if
        supported by the database), otherwise return a regular cursor.
        """
        return self.cursor()

    def make_debug_cursor(self, cursor):
        """Create a cursor that logs all queries in self.queries_log."""
        return utils.CursorDebugWrapper(cursor, self)

    def make_cursor(self, cursor):
        """Create a cursor without debug logging."""
        return utils.CursorWrapper(cursor, self)

    @contextmanager
    def temporary_connection(self):
        """
        Context manager that ensures that a connection is established, and
        if it opened one, closes it to avoid leaving a dangling connection.
        This is useful for operations outside of the request-response cycle.

        Provide a cursor: with self.temporary_connection() as cursor: ...
        """
        must_close = self.connection is None
        try:
            with self.cursor() as cursor:
                yield cursor
        finally:
            if must_close:
                self.close()

    @contextmanager
    def _nodb_cursor(self):
        """
        Return a cursor from an alternative connection to be used when there is
        no need to access the main database, specifically for test db
        creation/deletion. This also prevents the production database from
        being exposed to potential child threads while (or after) the test
        database is destroyed. Refs #10868, #17786, #16969.
        """
        conn = self.__class__({**self.settings_dict, "NAME": None}, alias=NO_DB_ALIAS)
        try:
            with conn.cursor() as cursor:
                yield cursor
        finally:
            conn.close()

    def schema_editor(self, *args, **kwargs):
        """
        Return a new instance of this backend's SchemaEditor.
        """
        if self.SchemaEditorClass is None:
            raise NotImplementedError(
                "The SchemaEditorClass attribute of this database wrapper is still None"
            )
        return self.SchemaEditorClass(self, *args, **kwargs)

    @staticmethod
    def _callable_name(func):
        """
        Return a readable name for ``func`` for use in log messages.

        functools.partial objects and instances defining __call__ have no
        ``__qualname__``; fall back to their repr so that logging the failure
        of a robust on-commit hook can never raise a second exception.
        """
        name = getattr(func, "__qualname__", None)
        return repr(func) if name is None else name

    def on_commit(self, func, robust=False):
        """
        Register ``func`` to run when the current transaction commits.

        Inside an 'atomic' block the callback is queued together with the set
        of active savepoint ids, so rolling back one of those savepoints
        discards it. With no atomic block and autocommit on, it runs
        immediately. With ``robust=True`` an exception raised by ``func`` is
        logged instead of propagated.

        Raises:
            TypeError: if ``func`` is not callable.
            TransactionManagementError: if autocommit is off outside an
                'atomic' block (manual transaction management).
        """
        if not callable(func):
            raise TypeError("on_commit()'s callback must be a callable.")
        if self.in_atomic_block:
            # Transaction in progress; save for execution on commit.
            self.run_on_commit.append((set(self.savepoint_ids), func, robust))
        elif not self.get_autocommit():
            raise TransactionManagementError(
                "on_commit() cannot be used in manual transaction management"
            )
        elif robust:
            # No transaction in progress and in autocommit mode; execute
            # immediately, logging instead of propagating failures.
            try:
                func()
            except Exception as e:
                logger.error(
                    "Error calling %s in on_commit() (%s).",
                    self._callable_name(func),
                    e,
                    exc_info=True,
                )
        else:
            # No transaction in progress and in autocommit mode; execute
            # immediately.
            func()

    def run_and_clear_commit_hooks(self):
        """
        Run every queued on-commit callback in registration order.

        The queue is detached before any callback runs, so callbacks
        registered meanwhile land in a fresh queue. Exceptions from robust
        callbacks are logged; an exception from a non-robust callback
        propagates and the remaining detached callbacks are dropped.

        Raises:
            TransactionManagementError: if an 'atomic' block is active.
        """
        self.validate_no_atomic_block()
        current_run_on_commit = self.run_on_commit
        self.run_on_commit = []
        # Plain iteration: repeatedly calling list.pop(0) made draining the
        # queue quadratic in the number of hooks.
        for _, func, robust in current_run_on_commit:
            if robust:
                try:
                    func()
                except Exception as e:
                    logger.error(
                        "Error calling %s in on_commit() during "
                        "transaction (%s).",
                        self._callable_name(func),
                        e,
                        exc_info=True,
                    )
            else:
                func()

    @contextmanager
    def execute_wrapper(self, wrapper):
        """
        Return a context manager under which the wrapper is applied to suitable
        database query executions.
        """
        self.execute_wrappers.append(wrapper)
        try:
            yield
        finally:
            self.execute_wrappers.pop()

    def copy(self, alias=None):
        """
        Return a copy of this connection.

        For tests that require two connections to the same database. The
        settings dict is deep-copied; ``alias`` defaults to this alias.
        """
        settings_dict = copy.deepcopy(self.settings_dict)
        if alias is None:
            alias = self.alias
        return type(self)(settings_dict, alias)