Plain is headed towards 1.0! Subscribe for development updates →

  1import pkgutil
  2from importlib import import_module
  3
  4from plain import signals
  5from plain.exceptions import ImproperlyConfigured
  6from plain.runtime import settings
  7from plain.utils.connection import BaseConnectionHandler, ConnectionProxy
  8from plain.utils.functional import cached_property
  9from plain.utils.module_loading import import_string
 10
 11DEFAULT_DB_ALIAS = "default"
 12PLAIN_VERSION_PICKLE_KEY = "_plain_version"
 13
 14
 15class Error(Exception):
 16    pass
 17
 18
 19class InterfaceError(Error):
 20    pass
 21
 22
 23class DatabaseError(Error):
 24    pass
 25
 26
 27class DataError(DatabaseError):
 28    pass
 29
 30
 31class OperationalError(DatabaseError):
 32    pass
 33
 34
 35class IntegrityError(DatabaseError):
 36    pass
 37
 38
 39class InternalError(DatabaseError):
 40    pass
 41
 42
 43class ProgrammingError(DatabaseError):
 44    pass
 45
 46
 47class NotSupportedError(DatabaseError):
 48    pass
 49
 50
 51class DatabaseErrorWrapper:
 52    """
 53    Context manager and decorator that reraises backend-specific database
 54    exceptions using Plain's common wrappers.
 55    """
 56
 57    def __init__(self, wrapper):
 58        """
 59        wrapper is a database wrapper.
 60
 61        It must have a Database attribute defining PEP-249 exceptions.
 62        """
 63        self.wrapper = wrapper
 64
 65    def __enter__(self):
 66        pass
 67
 68    def __exit__(self, exc_type, exc_value, traceback):
 69        if exc_type is None:
 70            return
 71        for plain_exc_type in (
 72            DataError,
 73            OperationalError,
 74            IntegrityError,
 75            InternalError,
 76            ProgrammingError,
 77            NotSupportedError,
 78            DatabaseError,
 79            InterfaceError,
 80            Error,
 81        ):
 82            db_exc_type = getattr(self.wrapper.Database, plain_exc_type.__name__)
 83            if issubclass(exc_type, db_exc_type):
 84                plain_exc_value = plain_exc_type(*exc_value.args)
 85                # Only set the 'errors_occurred' flag for errors that may make
 86                # the connection unusable.
 87                if plain_exc_type not in (DataError, IntegrityError):
 88                    self.wrapper.errors_occurred = True
 89                raise plain_exc_value.with_traceback(traceback) from exc_value
 90
 91    def __call__(self, func):
 92        # Note that we are intentionally not using @wraps here for performance
 93        # reasons. Refs #21109.
 94        def inner(*args, **kwargs):
 95            with self:
 96                return func(*args, **kwargs)
 97
 98        return inner
 99
100
def load_backend(backend_name):
    """
    Import and return the ``base`` module of the backend ``backend_name``.

    If that import fails and ``backend_name`` is not one of the built-in
    packages under ``plain.models.backends`` (``base`` and ``dummy`` are
    excluded), raise ImproperlyConfigured listing the built-in backends,
    chained to the original ImportError. If it *is* a built-in backend, the
    original ImportError is re-raised unchanged.
    """
    try:
        return import_module(f"{backend_name}.base")
    except ImportError as e_user:
        import plain.models.backends

        builtin_backends = []
        for _finder, module_name, is_package in pkgutil.iter_modules(
            plain.models.backends.__path__
        ):
            if is_package and module_name not in {"base", "dummy"}:
                builtin_backends.append(module_name)

        qualified_builtins = [
            "plain.models.backends." + module_name for module_name in builtin_backends
        ]
        if backend_name in qualified_builtins:
            # A built-in backend failed to import: this must be an error in
            # Plain, so surface the real exception.
            raise

        available = ", ".join(repr(module_name) for module_name in sorted(builtin_backends))
        raise ImproperlyConfigured(
            "{!r} isn't an available database backend or couldn't be "
            "imported. Check the above exception. To use one of the "
            "built-in backends, use 'plain.models.backends.XXX', where XXX "
            "is one of:\n"
            "    {}".format(backend_name, available)
        ) from e_user
130
131
class ConnectionHandler(BaseConnectionHandler):
    """
    Connection handler for the ``DATABASES`` setting.

    The generic per-alias machinery comes from ``BaseConnectionHandler``;
    this subclass validates the setting, fills in per-connection defaults and
    builds a backend ``DatabaseWrapper`` for a given alias.
    """

    settings_name = "DATABASES"

    def configure_settings(self, databases):
        """
        Validate ``databases`` and fill in missing per-connection defaults.

        Raises ImproperlyConfigured when no ``DEFAULT_DB_ALIAS`` entry exists.
        Defaults are applied with ``setdefault``, so values supplied by the
        user are never overwritten.
        """
        databases = super().configure_settings(databases)
        if DEFAULT_DB_ALIAS not in databases:
            raise ImproperlyConfigured(
                f"You must define a '{DEFAULT_DB_ALIAS}' database."
            )

        # Configure default settings.
        for conn in databases.values():
            conn.setdefault("AUTOCOMMIT", True)
            conn.setdefault("CONN_MAX_AGE", 0)
            conn.setdefault("CONN_HEALTH_CHECKS", False)
            # The literal is evaluated per connection, so aliases never share
            # one OPTIONS (or TEST) dict.
            conn.setdefault("OPTIONS", {})
            conn.setdefault("TIME_ZONE", None)
            for setting in ("NAME", "USER", "PASSWORD", "HOST", "PORT"):
                conn.setdefault(setting, "")

            test_settings = conn.setdefault("TEST", {})
            for key in ("CHARSET", "COLLATION", "MIRROR", "NAME"):
                test_settings.setdefault(key, None)
        return databases

    @property
    def databases(self):
        # Maintained for backward compatibility as some 3rd party packages have
        # made use of this private API in the past. It is no longer used within
        # Plain itself.
        return self.settings

    def create_connection(self, alias):
        """
        Build a new connection object for ``alias``.

        Imports the backend named by the alias's ``ENGINE`` setting (via
        ``load_backend``, whose errors propagate unchanged) and instantiates
        its ``DatabaseWrapper`` with the alias's settings and the alias.

        Raises ImproperlyConfigured if ``ENGINE`` is missing or empty.
        Previously these cases surfaced as a bare KeyError, or as the
        TypeError ``import_module(".base")`` raises for an empty name.
        """
        db = self.settings[alias]
        engine = db.get("ENGINE")
        if not engine:
            raise ImproperlyConfigured(
                f"settings.DATABASES[{alias!r}] is improperly configured. "
                "Please supply the ENGINE value."
            )
        backend = load_backend(engine)
        return backend.DatabaseWrapper(db, alias)
174
175
class ConnectionRouter:
    """
    Pick database aliases for reads, writes, relations and migrations by
    consulting a chain of router objects in order.

    A router may implement any subset of ``db_for_read``, ``db_for_write``,
    ``allow_relation`` and ``allow_migrate``; routers lacking the relevant
    method are skipped for that decision.
    """

    def __init__(self, routers=None):
        """
        If routers is not specified, default to settings.DATABASE_ROUTERS.

        Entries may be router objects or dotted import paths to router
        classes; the latter are instantiated with no arguments on first use.
        """
        self._routers = routers

    @cached_property
    def routers(self):
        """Resolved router instances, computed once per ConnectionRouter."""
        if self._routers is None:
            self._routers = settings.DATABASE_ROUTERS
        # Strings are dotted paths to router classes; anything else is used as-is.
        return [
            import_string(entry)() if isinstance(entry, str) else entry
            for entry in self._routers
        ]

    def _router_func(action):
        # Class-body factory (no ``self``): builds db_for_read/db_for_write,
        # which differ only in the name of the router hook they call.
        def _route_db(self, model, **hints):
            for candidate in self.routers:
                try:
                    hook = getattr(candidate, action)
                except AttributeError:
                    # If the router doesn't have a method, skip to the next one.
                    continue
                chosen = hook(model, **hints)
                if chosen:
                    return chosen
            # No router decided: keep the instance's current database when it
            # has one, otherwise fall back to the default alias.
            instance = hints.get("instance")
            if instance is not None and instance._state.db:
                return instance._state.db
            return DEFAULT_DB_ALIAS

        return _route_db

    db_for_read = _router_func("db_for_read")
    db_for_write = _router_func("db_for_write")

    def allow_relation(self, obj1, obj2, **hints):
        """
        Decide whether a relation between ``obj1`` and ``obj2`` is allowed.

        The first router whose ``allow_relation`` returns a non-None value
        decides; otherwise the relation is allowed only when both objects
        are on the same database.
        """
        for candidate in self.routers:
            try:
                hook = candidate.allow_relation
            except AttributeError:
                continue
            verdict = hook(obj1, obj2, **hints)
            if verdict is not None:
                return verdict
        return obj1._state.db == obj2._state.db

    def allow_migrate(self, db, package_label, **hints):
        """
        Decide whether migrations for ``package_label`` may run on ``db``.

        The first router whose ``allow_migrate`` returns a non-None value
        decides; with no opinion from any router, migration is allowed.
        """
        for candidate in self.routers:
            try:
                hook = candidate.allow_migrate
            except AttributeError:
                continue
            verdict = hook(db, package_label, **hints)
            if verdict is not None:
                return verdict
        return True

    def allow_migrate_model(self, db, model):
        """Apply ``allow_migrate`` using the model's package label and name."""
        meta = model._meta
        return self.allow_migrate(
            db,
            meta.package_label,
            model_name=meta.model_name,
            model=model,
        )

    def get_migratable_models(self, models_registry, package_label, db):
        """Return app models allowed to be migrated on provided db."""
        return [
            model
            for model in models_registry.get_models(package_label=package_label)
            if self.allow_migrate_model(db, model)
        ]
258
259
# Module-wide handler for the connections configured in ``DATABASES``.
connections: ConnectionHandler = ConnectionHandler()

# Module-wide router; since no routers are passed, it reads
# ``settings.DATABASE_ROUTERS`` the first time ``router.routers`` is used.
router: ConnectionRouter = ConnectionRouter()

# For backwards compatibility. Prefer connections['default'] instead.
# Proxy bound to ``connections`` under the DEFAULT_DB_ALIAS alias.
connection: ConnectionProxy = ConnectionProxy(connections, DEFAULT_DB_ALIAS)
266
267
def reset_queries(**kwargs):
    """
    Signal receiver: clear ``queries_log`` on every connection that has
    already been initialized. Extra signal arguments are ignored.
    """
    initialized = connections.all(initialized_only=True)
    for db_connection in initialized:
        db_connection.queries_log.clear()


# Reset saved queries whenever a Plain request is started.
signals.request_started.connect(reset_queries)
275
276
def close_old_connections(**kwargs):
    """
    Signal receiver: ask every already-initialized connection to close
    itself if it is unusable or obsolete (past its lifetime), via
    ``close_if_unusable_or_obsolete``. Extra signal arguments are ignored.
    """
    initialized = connections.all(initialized_only=True)
    for db_connection in initialized:
        db_connection.close_if_unusable_or_obsolete()


# Run the check both when a request starts and when it finishes, so stale
# connections are neither reused nor left open between requests.
signals.request_started.connect(close_old_connections)
signals.request_finished.connect(close_old_connections)