Plain is headed towards 1.0! Subscribe for development updates →

  1from functools import cached_property
  2from importlib import import_module
  3from threading import local
  4from typing import Any, TypedDict
  5
  6from plain.exceptions import ImproperlyConfigured
  7from plain.runtime import settings as plain_settings
  8from plain.utils.module_loading import import_string
  9
 10from .exceptions import ConnectionDoesNotExist
 11
# Alias of the database entry that must always be present in the configured
# databases; ConnectionRouter also returns it when no router picks a database.
DEFAULT_DB_ALIAS = "default"
 13
 14
 15class DatabaseConfig(TypedDict, total=False):
 16    AUTOCOMMIT: bool
 17    CONN_MAX_AGE: int | None
 18    CONN_HEALTH_CHECKS: bool
 19    DISABLE_SERVER_SIDE_CURSORS: bool
 20    ENGINE: str
 21    HOST: str
 22    NAME: str
 23    OPTIONS: dict[str, Any] | None
 24    PASSWORD: str
 25    PORT: str | int
 26    TEST: dict[str, Any]
 27    TIME_ZONE: str
 28    USER: str
 29
 30
 31class ConnectionHandler:
 32    """
 33    Handler for database connections. Provides lazy connection creation
 34    and convenience methods for managing multiple database connections.
 35    """
 36
 37    def __init__(self):
 38        self._settings: dict[str, DatabaseConfig] = {}
 39        self._connections = local()
 40
 41    @cached_property
 42    def settings(self) -> DatabaseConfig:
 43        self._settings = self.configure_settings()
 44        return self._settings
 45
 46    def configure_settings(self) -> DatabaseConfig:
 47        databases = plain_settings.DATABASES
 48
 49        if DEFAULT_DB_ALIAS not in databases:
 50            raise ImproperlyConfigured(
 51                f"You must define a '{DEFAULT_DB_ALIAS}' database."
 52            )
 53
 54        # Configure default settings.
 55        for conn in databases.values():
 56            conn.setdefault("AUTOCOMMIT", True)
 57            conn.setdefault("CONN_MAX_AGE", 0)
 58            conn.setdefault("CONN_HEALTH_CHECKS", False)
 59            conn.setdefault("OPTIONS", {})
 60            conn.setdefault("TIME_ZONE", None)
 61            for setting in ["NAME", "USER", "PASSWORD", "HOST", "PORT"]:
 62                conn.setdefault(setting, "")
 63
 64            test_settings = conn.setdefault("TEST", {})
 65            default_test_settings = [
 66                ("CHARSET", None),
 67                ("COLLATION", None),
 68                ("MIRROR", None),
 69                ("NAME", None),
 70            ]
 71            for key, value in default_test_settings:
 72                test_settings.setdefault(key, value)
 73
 74        return databases
 75
 76    def create_connection(self, alias):
 77        database_config = self.settings[alias]
 78        backend = import_module(f"{database_config['ENGINE']}.base")
 79        return backend.DatabaseWrapper(database_config, alias)
 80
 81    def __getitem__(self, alias):
 82        try:
 83            return getattr(self._connections, alias)
 84        except AttributeError:
 85            if alias not in self.settings:
 86                raise ConnectionDoesNotExist(f"The connection '{alias}' doesn't exist.")
 87        conn = self.create_connection(alias)
 88        setattr(self._connections, alias, conn)
 89        return conn
 90
 91    def __setitem__(self, key, value):
 92        setattr(self._connections, key, value)
 93
 94    def __delitem__(self, key):
 95        delattr(self._connections, key)
 96
 97    def __iter__(self):
 98        return iter(self.settings)
 99
100    def all(self, initialized_only=False):
101        return [
102            self[alias]
103            for alias in self
104            # If initialized_only is True, return only initialized connections.
105            if not initialized_only or hasattr(self._connections, alias)
106        ]
107
108    def close_all(self):
109        for conn in self.all(initialized_only=True):
110            conn.close()
111
112
class ConnectionRouter:
    """
    Asks each configured router, in order, which database to use or whether
    an operation is allowed, falling back to a built-in answer when no
    router gives one.
    """

    def __init__(self, routers=None):
        """
        If routers is not specified, default to settings.DATABASE_ROUTERS.
        """
        self._routers = routers

    @cached_property
    def routers(self):
        """
        Router objects, resolved once and cached.

        String entries are treated as dotted import paths and instantiated
        with no arguments; any other entry is used as-is.
        """
        if self._routers is None:
            self._routers = plain_settings.DATABASE_ROUTERS
        return [
            import_string(entry)() if isinstance(entry, str) else entry
            for entry in self._routers
        ]

    def _router_func(action):
        # Runs while the class body executes: builds the method that asks
        # each router for its `action` answer.
        def _route_db(self, model, **hints):
            for router in self.routers:
                try:
                    route = getattr(router, action)
                except AttributeError:
                    # If the router doesn't have a method, skip to the next one.
                    continue
                selected = route(model, **hints)
                if selected:
                    return selected
            # No router chose: prefer the database recorded on the instance
            # hint, otherwise use the default alias.
            instance = hints.get("instance")
            if instance is None or not instance._state.db:
                return DEFAULT_DB_ALIAS
            return instance._state.db

        return _route_db

    db_for_read = _router_func("db_for_read")
    db_for_write = _router_func("db_for_write")

    def allow_relation(self, obj1, obj2, **hints):
        """
        Return the first non-None router answer; with no answer, allow the
        relation only when both objects record the same database.
        """
        for router in self.routers:
            try:
                check = router.allow_relation
            except AttributeError:
                # If the router doesn't have a method, skip to the next one.
                continue
            verdict = check(obj1, obj2, **hints)
            if verdict is not None:
                return verdict
        return obj1._state.db == obj2._state.db

    def allow_migrate(self, db, package_label, **hints):
        """Return the first non-None router answer, or True when none answers."""
        for router in self.routers:
            try:
                check = router.allow_migrate
            except AttributeError:
                # If the router doesn't have a method, skip to the next one.
                pass
            else:
                verdict = check(db, package_label, **hints)
                if verdict is not None:
                    return verdict
        return True

    def allow_migrate_model(self, db, model):
        """Ask allow_migrate about ``model``, passing its package label and name."""
        meta = model._meta
        return self.allow_migrate(
            db,
            meta.package_label,
            model_name=meta.model_name,
            model=model,
        )

    def get_migratable_models(self, models_registry, package_label, db):
        """Return app models allowed to be migrated on provided db."""
        candidates = models_registry.get_models(package_label=package_label)
        return [
            candidate
            for candidate in candidates
            if self.allow_migrate_model(db, candidate)
        ]