1import pkgutil
2from importlib import import_module
3
4from plain import signals
5from plain.exceptions import ImproperlyConfigured
6from plain.runtime import settings
7from plain.utils.connection import BaseConnectionHandler, ConnectionProxy
8from plain.utils.functional import cached_property
9from plain.utils.module_loading import import_string
10
# Alias of the database used when nothing else is chosen: the DATABASES
# setting must define it (ConnectionHandler.configure_settings), and the
# router falls back to it when no router picks a database.
DEFAULT_DB_ALIAS = "default"
# Not used in this module. Presumably the key under which the Plain version
# is recorded in pickled model/queryset state — confirm against callers.
PLAIN_VERSION_PICKLE_KEY = "_plain_version"
13
14
class Error(Exception):
    """Root of the database exception hierarchy defined in this module."""


class InterfaceError(Error):
    """Raised in place of a DB-API driver's ``InterfaceError``."""


class DatabaseError(Error):
    """Raised in place of a driver's ``DatabaseError``; parent of the classes below."""


class DataError(DatabaseError):
    """Raised in place of a driver's ``DataError``."""


class OperationalError(DatabaseError):
    """Raised in place of a driver's ``OperationalError``."""


class IntegrityError(DatabaseError):
    """Raised in place of a driver's ``IntegrityError``."""


class InternalError(DatabaseError):
    """Raised in place of a driver's ``InternalError``."""


class ProgrammingError(DatabaseError):
    """Raised in place of a driver's ``ProgrammingError``."""


class NotSupportedError(DatabaseError):
    """Raised in place of a driver's ``NotSupportedError``."""
49
50
class DatabaseErrorWrapper:
    """
    Re-raise exceptions from a DB-API driver as Plain's own exception classes.

    Works both as a context manager (``with wrapper_obj:``) and as a decorator
    (``wrapper_obj(func)``). When the escaping exception is a subclass of one
    of the driver's exception classes, it is re-raised as the Plain class of
    the same name, with the driver exception chained as ``__cause__``.
    """

    def __init__(self, wrapper):
        """
        ``wrapper`` is a database wrapper. Its ``Database`` attribute must
        expose the PEP-249 exception classes by name, and its
        ``errors_occurred`` attribute is set when a translated error may
        have made the connection unusable.
        """
        self.wrapper = wrapper

    def __enter__(self):
        return None

    def __exit__(self, exc_type, exc_value, traceback):
        if exc_type is not None:
            driver = self.wrapper.Database
            # Specific classes come before DatabaseError/Error, which are their
            # bases (assuming the driver follows the PEP-249 hierarchy), so the
            # most precise match wins.
            candidates = (
                DataError,
                OperationalError,
                IntegrityError,
                InternalError,
                ProgrammingError,
                NotSupportedError,
                DatabaseError,
                InterfaceError,
                Error,
            )
            for target in candidates:
                if not issubclass(exc_type, getattr(driver, target.__name__)):
                    continue
                translated = target(*exc_value.args)
                # Data and integrity errors do not flag the connection; every
                # other kind may leave it unusable.
                if target not in (DataError, IntegrityError):
                    self.wrapper.errors_occurred = True
                raise translated.with_traceback(traceback) from exc_value

    def __call__(self, func):
        # functools.wraps is skipped on purpose, for performance (refs #21109).
        def inner(*args, **kwargs):
            with self:
                return func(*args, **kwargs)

        return inner
99
100
def load_backend(backend_name):
    """
    Import and return the ``base`` submodule of the backend package
    ``backend_name`` (a dotted module path).

    If the import fails:

    * ``backend_name`` is not a built-in ``plain.models.backends.*`` package:
      raise ImproperlyConfigured listing the built-in choices, chained to the
      ImportError.
    * ``backend_name`` is built-in: re-raise the original ImportError
      unchanged so the underlying cause stays visible.
    """
    try:
        return import_module(f"{backend_name}.base")
    except ImportError as e_user:
        # Imported lazily: only needed to build the error message.
        import plain.models.backends

        shipped = sorted(
            name
            for _, name, ispkg in pkgutil.iter_modules(plain.models.backends.__path__)
            if ispkg and name not in {"base", "dummy"}
        )
        if backend_name in {f"plain.models.backends.{name}" for name in shipped}:
            # A built-in backend failed to import: surface the real error.
            raise
        choices = ", ".join(repr(name) for name in shipped)
        raise ImproperlyConfigured(
            "{!r} isn't an available database backend or couldn't be "
            "imported. Check the above exception. To use one of the "
            "built-in backends, use 'plain.models.backends.XXX', where XXX "
            "is one of:\n"
            " {}".format(backend_name, choices)
        ) from e_user
130
131
class ConnectionHandler(BaseConnectionHandler):
    """
    Connection handler configured from the ``DATABASES`` setting; builds one
    backend ``DatabaseWrapper`` per alias via ``create_connection``.
    """

    settings_name = "DATABASES"

    def configure_settings(self, databases):
        """
        Validate the database mapping and fill in default values per alias.

        Each alias's dict (and its ``TEST`` sub-dict) is updated in place with
        ``setdefault``, so explicitly configured values are never overwritten.
        Returns the mapping obtained from the base class, with those defaults
        applied.

        Raises ImproperlyConfigured if no ``DEFAULT_DB_ALIAS`` entry exists.
        """
        databases = super().configure_settings(databases)
        if DEFAULT_DB_ALIAS not in databases:
            raise ImproperlyConfigured(
                f"You must define a '{DEFAULT_DB_ALIAS}' database."
            )

        # Configure default settings.
        for conn in databases.values():
            conn.setdefault("AUTOCOMMIT", True)
            conn.setdefault("CONN_MAX_AGE", 0)
            conn.setdefault("CONN_HEALTH_CHECKS", False)
            conn.setdefault("OPTIONS", {})
            conn.setdefault("TIME_ZONE", None)
            for setting in ["NAME", "USER", "PASSWORD", "HOST", "PORT"]:
                conn.setdefault(setting, "")

            test_settings = conn.setdefault("TEST", {})
            default_test_settings = [
                ("CHARSET", None),
                ("COLLATION", None),
                ("MIRROR", None),
                ("NAME", None),
            ]
            for key, value in default_test_settings:
                test_settings.setdefault(key, value)
        return databases

    @property
    def databases(self):
        # Maintained for backward compatibility as some 3rd party packages have
        # made use of this private API in the past. It is no longer used within
        # Plain itself.
        return self.settings

    def create_connection(self, alias):
        """
        Build a new backend ``DatabaseWrapper`` for ``alias``.

        Raises ImproperlyConfigured when the alias has no usable ``ENGINE``.
        Without this check, a missing key surfaced as a bare ``KeyError`` and
        an empty string as a ``TypeError`` from ``import_module(".base")``.
        """
        db = self.settings[alias]
        engine = db.get("ENGINE")
        if not engine:
            raise ImproperlyConfigured(
                f"The '{alias}' database is missing an 'ENGINE' setting "
                "(a dotted path to a database backend package)."
            )
        backend = load_backend(engine)
        return backend.DatabaseWrapper(db, alias)
174
175
class ConnectionRouter:
    """
    Ask the configured database routers which database to use and whether
    relations and migrations are allowed.

    Routers are consulted in order. A router lacking the relevant method is
    skipped, and the first decisive answer wins:

    * ``db_for_read`` / ``db_for_write``: a truthy alias.
    * ``allow_relation`` / ``allow_migrate``: a non-None value.

    Otherwise a built-in fallback is used.
    """

    def __init__(self, routers=None):
        """
        ``routers`` is a sequence of router objects and/or dotted import paths
        to router classes. When it is None, ``settings.DATABASE_ROUTERS`` is
        read the first time ``routers`` is accessed.
        """
        self._routers = routers

    @cached_property
    def routers(self):
        """Router instances, built once; string entries are imported and instantiated."""
        if self._routers is None:
            self._routers = settings.DATABASE_ROUTERS
        return [
            import_string(entry)() if isinstance(entry, str) else entry
            for entry in self._routers
        ]

    def _iter_router_methods(self, name):
        """Yield the ``name`` attribute of each router that defines it, in order."""
        for candidate in self.routers:
            try:
                hook = getattr(candidate, name)
            except AttributeError:
                # Routers may implement any subset of the routing methods.
                continue
            yield hook

    def _router_func(action):
        # Class-body helper that builds db_for_read / db_for_write below.
        def _route_db(self, model, **hints):
            for pick in self._iter_router_methods(action):
                chosen_db = pick(model, **hints)
                if chosen_db:
                    return chosen_db
            # No router decided: stay on the instance's database when known.
            instance = hints.get("instance")
            if instance is None or not instance._state.db:
                return DEFAULT_DB_ALIAS
            return instance._state.db

        return _route_db

    db_for_read = _router_func("db_for_read")
    db_for_write = _router_func("db_for_write")

    def allow_relation(self, obj1, obj2, **hints):
        """
        Return the first non-None router verdict. Without one, allow the
        relation only when both objects are on the same database.
        """
        for check in self._iter_router_methods("allow_relation"):
            verdict = check(obj1, obj2, **hints)
            if verdict is not None:
                return verdict
        return obj1._state.db == obj2._state.db

    def allow_migrate(self, db, package_label, **hints):
        """Return the first non-None router verdict, or True if none decides."""
        for check in self._iter_router_methods("allow_migrate"):
            verdict = check(db, package_label, **hints)
            if verdict is not None:
                return verdict
        return True

    def allow_migrate_model(self, db, model):
        """Ask ``allow_migrate`` about ``model``, passing its label and name as hints."""
        meta = model._meta
        return self.allow_migrate(
            db, meta.package_label, model_name=meta.model_name, model=model
        )

    def get_migratable_models(self, models_registry, package_label, db):
        """Return app models allowed to be migrated on provided db."""
        candidates = models_registry.get_models(package_label=package_label)
        return [m for m in candidates if self.allow_migrate_model(db, m)]
258
259
# Module-level connection handler. ConnectionHandler.create_connection builds
# the wrapper for a given alias; presumably the base handler calls it lazily
# on first lookup of that alias — confirm in plain.utils.connection.
connections = ConnectionHandler()

# Module-level router. No routers are passed, so it reads
# settings.DATABASE_ROUTERS the first time `router.routers` is accessed.
router = ConnectionRouter()

# For backwards compatibility. Prefer connections['default'] instead.
# Presumably forwards attribute access to connections[DEFAULT_DB_ALIAS]
# (ConnectionProxy is defined in plain.utils.connection).
connection = ConnectionProxy(connections, DEFAULT_DB_ALIAS)
266
267
# Hooked to request_started below so each Plain request begins with an empty
# query log on every already-initialized connection.
def reset_queries(**kwargs):
    """Clear ``queries_log`` on every initialized connection; signal kwargs are ignored."""
    opened = connections.all(initialized_only=True)
    for db_conn in opened:
        db_conn.queries_log.clear()


signals.request_started.connect(reset_queries)
275
276
# Hooked to both request_started and request_finished below: gives every
# initialized connection a chance to drop itself when it is unusable or has
# outlived its allowed lifetime.
def close_old_connections(**kwargs):
    """Call ``close_if_unusable_or_obsolete()`` on each initialized connection; signal kwargs are ignored."""
    opened = connections.all(initialized_only=True)
    for db_conn in opened:
        db_conn.close_if_unusable_or_obsolete()


signals.request_started.connect(close_old_connections)
signals.request_finished.connect(close_old_connections)