1import _thread
2import copy
3import datetime
4import logging
5import threading
6import time
7import warnings
8import zoneinfo
9from collections import deque
10from contextlib import contextmanager
11from functools import cached_property
12
13from plain.models.backends import utils
14from plain.models.backends.base.validation import BaseDatabaseValidation
15from plain.models.backends.utils import debug_transaction
16from plain.models.db import (
17 DatabaseError,
18 DatabaseErrorWrapper,
19 NotSupportedError,
20)
21from plain.models.transaction import TransactionManagementError
22from plain.runtime import settings
23
# Logger used by BaseDatabaseWrapper to report failures of "robust"
# on-commit callbacks.
logger = logging.getLogger("plain.models.backends.base")

# Flipped to True by BaseDatabaseWrapper.init_connection_state() after the
# first successful database version check, so later connections skip it.
RAN_DB_VERSION_CHECK = False
27
28
class BaseDatabaseWrapper:
    """
    Represent a database connection.

    Backends subclass this, set the ``*_class`` attributes (instantiated with
    the wrapper in ``__init__()``), and implement the hooks that raise
    ``NotImplementedError`` here: ``get_database_version()``,
    ``get_connection_params()``, ``get_new_connection()``,
    ``create_cursor()``, ``_set_autocommit()`` and ``is_usable()``.

    The underlying DB-API connection is opened lazily by
    ``ensure_connection()``. The wrapper also tracks transaction state used
    by ``atomic`` blocks, savepoints and on-commit callbacks.
    """

    # Mapping of Field objects to their column types.
    data_types = {}
    # Mapping of Field objects to their SQL suffix such as AUTOINCREMENT.
    data_types_suffix = {}
    # Mapping of Field objects to their SQL for CHECK constraints.
    data_type_check_constraints = {}
    ops = None
    vendor = "unknown"
    display_name = "unknown"
    SchemaEditorClass = None
    # Classes instantiated in __init__().
    client_class = None
    creation_class = None
    features_class = None
    introspection_class = None
    ops_class = None
    validation_class = BaseDatabaseValidation

    # Maximum number of entries kept in queries_log; older ones are dropped.
    queries_limit = 9000

    def __init__(self, settings_dict):
        """
        Initialize wrapper state without opening a database connection.

        ``settings_dict`` holds this database's settings. Keys read in this
        class include NAME, TIME_ZONE, AUTOCOMMIT, CONN_MAX_AGE and
        CONN_HEALTH_CHECKS.
        """
        # Connection related attributes.
        # The underlying database connection.
        self.connection = None
        # `settings_dict` should be a dictionary containing keys such as
        # NAME, USER, etc. It's called `settings_dict` instead of `settings`
        # to disambiguate it from Plain settings modules.
        self.settings_dict = settings_dict
        # Query logging in debug mode or when explicitly enabled.
        self.queries_log = deque(maxlen=self.queries_limit)
        self.force_debug_cursor = False

        # Transaction related attributes.
        # Tracks if the connection is in autocommit mode. Per PEP 249, by
        # default, it isn't.
        self.autocommit = False
        # Tracks if the connection is in a transaction managed by 'atomic'.
        self.in_atomic_block = False
        # Increment to generate unique savepoint ids.
        self.savepoint_state = 0
        # List of savepoints created by 'atomic'.
        self.savepoint_ids = []
        # Stack of active 'atomic' blocks.
        self.atomic_blocks = []
        # Tracks if the outermost 'atomic' block should commit on exit,
        # ie. if autocommit was active on entry.
        self.commit_on_exit = True
        # Tracks if the transaction should be rolled back to the next
        # available savepoint because of an exception in an inner block.
        self.needs_rollback = False
        self.rollback_exc = None

        # Connection termination related attributes.
        self.close_at = None
        self.closed_in_transaction = False
        self.errors_occurred = False
        self.health_check_enabled = False
        self.health_check_done = False

        # Thread-safety related attributes.
        self._thread_sharing_lock = threading.Lock()
        self._thread_sharing_count = 0
        self._thread_ident = _thread.get_ident()

        # A list of no-argument functions to run when the transaction commits.
        # Each entry is an (sids, func, robust) tuple, where sids is a set of
        # the active savepoint IDs when this function was registered and robust
        # specifies whether it's allowed for the function to fail.
        self.run_on_commit = []

        # Should we run the on-commit hooks the next time set_autocommit(True)
        # is called?
        self.run_commit_hooks_on_set_autocommit_on = False

        # A stack of wrappers to be invoked around execute()/executemany()
        # calls. Each entry is a function taking five arguments: execute, sql,
        # params, many, and context. It's the function's responsibility to
        # call execute(sql, params, many, context).
        self.execute_wrappers = []

        self.client = self.client_class(self)
        self.creation = self.creation_class(self)
        self.features = self.features_class(self)
        self.introspection = self.introspection_class(self)
        self.ops = self.ops_class(self)
        self.validation = self.validation_class(self)

    def __repr__(self):
        return f"<{self.__class__.__qualname__} vendor={self.vendor!r}>"

    def ensure_timezone(self):
        """
        Ensure the connection's timezone is set to `self.timezone_name` and
        return whether it changed or not.

        The base implementation does nothing and returns False.
        """
        return False

    @cached_property
    def timezone(self):
        """
        Return a tzinfo of the database connection time zone.

        This is only used when time zone support is enabled. When a datetime is
        read from the database, it is always returned in this time zone.

        When the database backend supports time zones, it doesn't matter which
        time zone Plain uses, as long as aware datetimes are used everywhere.
        Other users connecting to the database can choose their own time zone.

        When the database backend doesn't support time zones, the time zone
        Plain uses may be constrained by the requirements of other users of
        the database.

        A TIME_ZONE setting of None means UTC.
        """
        if self.settings_dict["TIME_ZONE"] is None:
            return datetime.UTC
        else:
            return zoneinfo.ZoneInfo(self.settings_dict["TIME_ZONE"])

    @cached_property
    def timezone_name(self):
        """
        Name of the time zone of the database connection ("UTC" when the
        TIME_ZONE setting is None).
        """
        if self.settings_dict["TIME_ZONE"] is None:
            return "UTC"
        else:
            return self.settings_dict["TIME_ZONE"]

    @property
    def queries_logged(self):
        """Whether cursors should record queries in ``queries_log``."""
        return self.force_debug_cursor or settings.DEBUG

    @property
    def queries(self):
        """
        Return the logged queries as a list.

        Emit a warning when the log is full, since older queries have then
        been discarded by the bounded deque.
        """
        if len(self.queries_log) == self.queries_log.maxlen:
            warnings.warn(
                f"Limit for query logging exceeded, only the last {self.queries_log.maxlen} queries "
                "will be returned.",
                # Attribute the warning to the code reading `.queries`.
                stacklevel=2,
            )
        return list(self.queries_log)

    def get_database_version(self):
        """Return a tuple of the database's version."""
        raise NotImplementedError(
            "subclasses of BaseDatabaseWrapper may require a get_database_version() "
            "method."
        )

    def check_database_version_supported(self):
        """
        Raise an error if the database version isn't supported by this
        version of Plain.

        Raises:
            NotSupportedError: if the server version is lower than
                ``features.minimum_database_version``.
        """
        minimum_version = self.features.minimum_database_version
        if minimum_version is None:
            return
        # Fetch the version once: backends may need a query to obtain it.
        database_version = self.get_database_version()
        if database_version < minimum_version:
            db_version = ".".join(map(str, database_version))
            min_db_version = ".".join(map(str, minimum_version))
            raise NotSupportedError(
                f"{self.display_name} {min_db_version} or later is required "
                f"(found {db_version})."
            )

    # ##### Backend-specific methods for creating connections and cursors #####

    def get_connection_params(self):
        """Return a dict of parameters suitable for get_new_connection."""
        raise NotImplementedError(
            "subclasses of BaseDatabaseWrapper may require a get_connection_params() "
            "method"
        )

    def get_new_connection(self, conn_params):
        """Open a connection to the database."""
        raise NotImplementedError(
            "subclasses of BaseDatabaseWrapper may require a get_new_connection() "
            "method"
        )

    def init_connection_state(self):
        """
        Initialize the database connection settings.

        The base implementation checks the database version. A module-level
        flag limits the check to the first connection that reaches this point.
        """
        global RAN_DB_VERSION_CHECK
        if not RAN_DB_VERSION_CHECK:
            self.check_database_version_supported()
            RAN_DB_VERSION_CHECK = True

    def create_cursor(self, name=None):
        """Create a cursor. Assume that a connection is established."""
        raise NotImplementedError(
            "subclasses of BaseDatabaseWrapper may require a create_cursor() method"
        )

    # ##### Backend-specific methods for creating connections #####

    def connect(self):
        """Connect to the database. Assume that the connection is closed."""
        # In case the previous connection was closed while in an atomic block
        self.in_atomic_block = False
        self.savepoint_ids = []
        self.atomic_blocks = []
        self.needs_rollback = False
        # Reset parameters defining when to close/health-check the connection.
        self.health_check_enabled = self.settings_dict["CONN_HEALTH_CHECKS"]
        max_age = self.settings_dict["CONN_MAX_AGE"]
        # A CONN_MAX_AGE of None means the connection never expires by age.
        self.close_at = None if max_age is None else time.monotonic() + max_age
        self.closed_in_transaction = False
        self.errors_occurred = False
        # New connections are healthy.
        self.health_check_done = True
        # Establish the connection
        conn_params = self.get_connection_params()
        self.connection = self.get_new_connection(conn_params)
        self.set_autocommit(self.settings_dict["AUTOCOMMIT"])
        self.init_connection_state()

        self.run_on_commit = []

    def ensure_connection(self):
        """Guarantee that a connection to the database is established."""
        if self.connection is None:
            with self.wrap_database_errors:
                self.connect()

    # ##### Backend-specific wrappers for PEP-249 connection methods #####

    def _prepare_cursor(self, cursor):
        """
        Validate the connection is usable and perform database cursor wrapping.

        The cursor gets a logging wrapper when ``queries_logged`` is true.
        """
        self.validate_thread_sharing()
        if self.queries_logged:
            wrapped_cursor = self.make_debug_cursor(cursor)
        else:
            wrapped_cursor = self.make_cursor(cursor)
        return wrapped_cursor

    def _cursor(self, name=None):
        """Health-check, connect if needed, and return a wrapped cursor."""
        self.close_if_health_check_failed()
        self.ensure_connection()
        with self.wrap_database_errors:
            return self._prepare_cursor(self.create_cursor(name))

    def _commit(self):
        """Commit on the raw connection, if one is open."""
        if self.connection is not None:
            with debug_transaction(self, "COMMIT"), self.wrap_database_errors:
                return self.connection.commit()

    def _rollback(self):
        """Roll back on the raw connection, if one is open."""
        if self.connection is not None:
            with debug_transaction(self, "ROLLBACK"), self.wrap_database_errors:
                return self.connection.rollback()

    def _close(self):
        """Close the raw connection, if one is open."""
        if self.connection is not None:
            with self.wrap_database_errors:
                return self.connection.close()

    # ##### Generic wrappers for PEP-249 connection methods #####

    def cursor(self):
        """Create a cursor, opening a connection if necessary."""
        return self._cursor()

    def commit(self):
        """
        Commit a transaction and reset the dirty flag.

        On-commit hooks are not run here. They are flagged to run on the next
        ``set_autocommit(True)``.
        """
        self.validate_thread_sharing()
        self.validate_no_atomic_block()
        self._commit()
        # A successful commit means that the database connection works.
        self.errors_occurred = False
        self.run_commit_hooks_on_set_autocommit_on = True

    def rollback(self):
        """Roll back a transaction and reset the dirty flag."""
        self.validate_thread_sharing()
        self.validate_no_atomic_block()
        self._rollback()
        # A successful rollback means that the database connection works.
        self.errors_occurred = False
        self.needs_rollback = False
        # Callbacks registered for the rolled-back work must never run.
        self.run_on_commit = []

    def close(self):
        """
        Close the connection to the database.

        Inside an atomic block, the wrapper keeps its connection reference
        and is marked as closed in transaction and needing rollback.
        """
        self.validate_thread_sharing()
        self.run_on_commit = []

        # Don't call validate_no_atomic_block() to avoid making it difficult
        # to get rid of a connection in an invalid state. The next connect()
        # will reset the transaction state anyway.
        if self.closed_in_transaction or self.connection is None:
            return
        try:
            self._close()
        finally:
            if self.in_atomic_block:
                self.closed_in_transaction = True
                self.needs_rollback = True
            else:
                self.connection = None

    # ##### Backend-specific savepoint management methods #####

    def _savepoint(self, sid):
        with self.cursor() as cursor:
            cursor.execute(self.ops.savepoint_create_sql(sid))

    def _savepoint_rollback(self, sid):
        with self.cursor() as cursor:
            cursor.execute(self.ops.savepoint_rollback_sql(sid))

    def _savepoint_commit(self, sid):
        with self.cursor() as cursor:
            cursor.execute(self.ops.savepoint_commit_sql(sid))

    def _savepoint_allowed(self):
        # Savepoints cannot be created outside a transaction
        return self.features.uses_savepoints and not self.get_autocommit()

    # ##### Generic savepoint management methods #####

    def savepoint(self):
        """
        Create a savepoint inside the current transaction. Return an
        identifier for the savepoint that will be used for the subsequent
        rollback or commit. Do nothing if savepoints are not supported.
        """
        if not self._savepoint_allowed():
            return

        thread_ident = _thread.get_ident()
        # Strip any "-" from the thread id so it is safe in an identifier.
        tid = str(thread_ident).replace("-", "")

        self.savepoint_state += 1
        sid = f"s{tid}_x{self.savepoint_state}"

        self.validate_thread_sharing()
        self._savepoint(sid)

        return sid

    def savepoint_rollback(self, sid):
        """
        Roll back to a savepoint. Do nothing if savepoints are not supported.
        """
        if not self._savepoint_allowed():
            return

        self.validate_thread_sharing()
        self._savepoint_rollback(sid)

        # Remove any callbacks registered while this savepoint was active.
        self.run_on_commit = [
            (sids, func, robust)
            for (sids, func, robust) in self.run_on_commit
            if sid not in sids
        ]

    def savepoint_commit(self, sid):
        """
        Release a savepoint. Do nothing if savepoints are not supported.
        """
        if not self._savepoint_allowed():
            return

        self.validate_thread_sharing()
        self._savepoint_commit(sid)

    def clean_savepoints(self):
        """
        Reset the counter used to generate unique savepoint ids in this thread.
        """
        self.savepoint_state = 0

    # ##### Backend-specific transaction management methods #####

    def _set_autocommit(self, autocommit):
        """
        Backend-specific implementation to enable or disable autocommit.
        """
        raise NotImplementedError(
            "subclasses of BaseDatabaseWrapper may require a _set_autocommit() method"
        )

    # ##### Generic transaction management methods #####

    def get_autocommit(self):
        """Get the autocommit state, connecting first if necessary."""
        self.ensure_connection()
        return self.autocommit

    def set_autocommit(
        self, autocommit, force_begin_transaction_with_broken_autocommit=False
    ):
        """
        Enable or disable autocommit.

        The usual way to start a transaction is to turn autocommit off.
        SQLite does not properly start a transaction when disabling
        autocommit. To avoid this buggy behavior and to actually enter a new
        transaction, an explicit BEGIN is required. Using
        force_begin_transaction_with_broken_autocommit=True will issue an
        explicit BEGIN with SQLite. This option will be ignored for other
        backends.

        Turning autocommit on after a ``commit()`` runs the pending on-commit
        hooks.
        """
        self.validate_no_atomic_block()
        self.close_if_health_check_failed()
        self.ensure_connection()

        start_transaction_under_autocommit = (
            force_begin_transaction_with_broken_autocommit
            and not autocommit
            and hasattr(self, "_start_transaction_under_autocommit")
        )

        if start_transaction_under_autocommit:
            self._start_transaction_under_autocommit()
        elif autocommit:
            self._set_autocommit(autocommit)
        else:
            with debug_transaction(self, "BEGIN"):
                self._set_autocommit(autocommit)
        self.autocommit = autocommit

        if autocommit and self.run_commit_hooks_on_set_autocommit_on:
            self.run_and_clear_commit_hooks()
            self.run_commit_hooks_on_set_autocommit_on = False

    def get_rollback(self):
        """Get the "needs rollback" flag -- for *advanced use* only."""
        if not self.in_atomic_block:
            raise TransactionManagementError(
                "The rollback flag doesn't work outside of an 'atomic' block."
            )
        return self.needs_rollback

    def set_rollback(self, rollback):
        """
        Set or unset the "needs rollback" flag -- for *advanced use* only.
        """
        if not self.in_atomic_block:
            raise TransactionManagementError(
                "The rollback flag doesn't work outside of an 'atomic' block."
            )
        self.needs_rollback = rollback

    def validate_no_atomic_block(self):
        """Raise an error if an atomic block is active."""
        if self.in_atomic_block:
            raise TransactionManagementError(
                "This is forbidden when an 'atomic' block is active."
            )

    def validate_no_broken_transaction(self):
        """
        Raise TransactionManagementError if the transaction needs rollback.

        The error is chained to ``rollback_exc`` so the original cause stays
        visible in tracebacks.
        """
        if self.needs_rollback:
            raise TransactionManagementError(
                "An error occurred in the current transaction. You can't "
                "execute queries until the end of the 'atomic' block."
            ) from self.rollback_exc

    # ##### Foreign key constraints checks handling #####

    def disable_constraint_checking(self):
        """
        Backends can implement as needed to temporarily disable foreign key
        constraint checking. Should return True if the constraints were
        disabled and will need to be reenabled.
        """
        return False

    def enable_constraint_checking(self):
        """
        Backends can implement as needed to re-enable foreign key constraint
        checking.
        """
        pass

    def check_constraints(self, table_names=None):
        """
        Backends can override this method if they can apply constraint
        checking (e.g. via "SET CONSTRAINTS ALL IMMEDIATE"). Should raise an
        IntegrityError if any invalid foreign key references are encountered.
        """
        pass

    # ##### Connection termination handling #####

    def is_usable(self):
        """
        Test if the database connection is usable.

        This method may assume that self.connection is not None.

        Actual implementations should take care not to raise exceptions
        as that may prevent Plain from recycling unusable connections.
        """
        raise NotImplementedError(
            "subclasses of BaseDatabaseWrapper may require an is_usable() method"
        )

    def close_if_health_check_failed(self):
        """
        Close existing connection if it fails a health check.

        At most one check runs per ``health_check_done`` cycle. The flag is
        cleared by ``close_if_unusable_or_obsolete()``.
        """
        if (
            self.connection is None
            or not self.health_check_enabled
            or self.health_check_done
        ):
            return

        if not self.is_usable():
            self.close()
        self.health_check_done = True

    def close_if_unusable_or_obsolete(self):
        """
        Close the current connection if unrecoverable errors have occurred
        or if it outlived its maximum age.
        """
        if self.connection is not None:
            # Re-arm the health check for the next cursor request.
            self.health_check_done = False
            # If the application didn't restore the original autocommit setting,
            # don't take chances, drop the connection.
            if self.get_autocommit() != self.settings_dict["AUTOCOMMIT"]:
                self.close()
                return

            # If an exception other than DataError or IntegrityError occurred
            # since the last commit / rollback, check if the connection works.
            if self.errors_occurred:
                if self.is_usable():
                    self.errors_occurred = False
                    self.health_check_done = True
                else:
                    self.close()
                    return

            if self.close_at is not None and time.monotonic() >= self.close_at:
                self.close()
                return

    # ##### Thread safety handling #####

    @property
    def allow_thread_sharing(self):
        """Whether the thread-sharing counter is currently positive."""
        with self._thread_sharing_lock:
            return self._thread_sharing_count > 0

    def validate_thread_sharing(self):
        """
        Validate that the connection isn't accessed by another thread than the
        one which originally created it, unless thread sharing is enabled
        (``allow_thread_sharing`` is true, i.e. ``_thread_sharing_count`` is
        positive). Raise an exception if the validation fails.

        Raises:
            DatabaseError: when called from a foreign thread without sharing.
        """
        if not (self.allow_thread_sharing or self._thread_ident == _thread.get_ident()):
            raise DatabaseError(
                "DatabaseWrapper objects created in a "
                "thread can only be used in that same thread. The connection "
                f"was created in thread id {self._thread_ident} and this is "
                f"thread id {_thread.get_ident()}."
            )

    # ##### Miscellaneous #####

    def prepare_database(self):
        """
        Hook to do any database check or preparation, generally called before
        migrating a project or an app.
        """
        pass

    @cached_property
    def wrap_database_errors(self):
        """
        Context manager and decorator that re-throws backend-specific database
        exceptions using Plain's common wrappers.
        """
        return DatabaseErrorWrapper(self)

    def chunked_cursor(self):
        """
        Return a cursor that tries to avoid caching in the database (if
        supported by the database), otherwise return a regular cursor.
        """
        return self.cursor()

    def make_debug_cursor(self, cursor):
        """Create a cursor that logs all queries in self.queries_log."""
        return utils.CursorDebugWrapper(cursor, self)

    def make_cursor(self, cursor):
        """Create a cursor without debug logging."""
        return utils.CursorWrapper(cursor, self)

    @contextmanager
    def temporary_connection(self):
        """
        Context manager that ensures that a connection is established, and
        if it opened one, closes it to avoid leaving a dangling connection.
        This is useful for operations outside of the request-response cycle.

        Provide a cursor: with self.temporary_connection() as cursor: ...
        """
        must_close = self.connection is None
        try:
            with self.cursor() as cursor:
                yield cursor
        finally:
            if must_close:
                self.close()

    @contextmanager
    def _nodb_cursor(self):
        """
        Return a cursor from an alternative connection to be used when there is
        no need to access the main database, specifically for test db
        creation/deletion. This also prevents the production database from
        being exposed to potential child threads while (or after) the test
        database is destroyed. Refs #10868, #17786, #16969.

        The alternative connection uses a copy of the settings with NAME set
        to None and is always closed on exit.
        """
        conn = self.__class__({**self.settings_dict, "NAME": None})
        try:
            with conn.cursor() as cursor:
                yield cursor
        finally:
            conn.close()

    def schema_editor(self, *args, **kwargs):
        """
        Return a new instance of this backend's SchemaEditor.

        Raises:
            NotImplementedError: if ``SchemaEditorClass`` is not set.
        """
        if self.SchemaEditorClass is None:
            raise NotImplementedError(
                "The SchemaEditorClass attribute of this database wrapper is still None"
            )
        return self.SchemaEditorClass(self, *args, **kwargs)

    @staticmethod
    def _callback_name(func):
        """Return a printable name for an on-commit callback, for logging."""
        # functools.partial objects and callable instances have no
        # __qualname__; fall back to repr() so that logging a robust
        # callback's failure can never itself raise AttributeError.
        return getattr(func, "__qualname__", repr(func))

    def on_commit(self, func, robust=False):
        """
        Register ``func`` to run when the current transaction commits.

        Inside an ``atomic`` block the callback is queued together with the
        currently active savepoint ids, so rolling back to one of those
        savepoints discards it. Outside any atomic block in autocommit mode
        the callback runs immediately.

        If ``robust`` is true, an exception raised by ``func`` is logged
        instead of propagated.

        Raises:
            TypeError: if ``func`` is not callable.
            TransactionManagementError: if called outside an atomic block
                while autocommit is off.
        """
        if not callable(func):
            raise TypeError("on_commit()'s callback must be a callable.")
        if self.in_atomic_block:
            # Transaction in progress; save for execution on commit.
            self.run_on_commit.append((set(self.savepoint_ids), func, robust))
        elif not self.get_autocommit():
            raise TransactionManagementError(
                "on_commit() cannot be used in manual transaction management"
            )
        elif robust:
            # No transaction in progress and in autocommit mode; execute
            # immediately, logging rather than propagating failures.
            try:
                func()
            except Exception as e:
                logger.error(
                    "Error calling %s in on_commit() (%s).",
                    self._callback_name(func),
                    e,
                    exc_info=True,
                )
        else:
            # No transaction in progress and in autocommit mode; execute
            # immediately.
            func()

    def run_and_clear_commit_hooks(self):
        """
        Run every queued on-commit callback and leave the queue empty.

        The queue is swapped for a fresh list before any callback runs, so
        callbacks registered meanwhile are not part of this batch. If a
        non-robust callback raises, the exception propagates and the rest of
        the batch is dropped. Robust callbacks have their exceptions logged.

        Raises:
            TransactionManagementError: if an ``atomic`` block is active.
        """
        self.validate_no_atomic_block()
        current_run_on_commit = self.run_on_commit
        self.run_on_commit = []
        # Plain iteration instead of repeated pop(0) (O(n) each): the list is
        # local, so consuming it in place has no observable effect.
        for _, func, robust in current_run_on_commit:
            if robust:
                try:
                    func()
                except Exception as e:
                    logger.error(
                        "Error calling %s in on_commit() during transaction (%s).",
                        self._callback_name(func),
                        e,
                        exc_info=True,
                    )
            else:
                func()

    @contextmanager
    def execute_wrapper(self, wrapper):
        """
        Return a context manager under which the wrapper is applied to suitable
        database query executions.
        """
        self.execute_wrappers.append(wrapper)
        try:
            yield
        finally:
            self.execute_wrappers.pop()

    def copy(self):
        """
        Return a copy of this connection.

        For tests that require two connections to the same database. The new
        wrapper gets a deep copy of ``settings_dict`` and starts unconnected.
        """
        settings_dict = copy.deepcopy(self.settings_dict)
        return type(self)(settings_dict)