1from functools import cached_property
2from importlib import import_module
3from threading import local
4from typing import Any, TypedDict
5
6from plain.exceptions import ImproperlyConfigured
7from plain.runtime import settings as plain_settings
8from plain.utils.module_loading import import_string
9
10from .exceptions import ConnectionDoesNotExist
11
# Alias that must be present in settings.DATABASES (checked in
# ConnectionHandler.configure_settings) and that ConnectionRouter falls back
# to when no router or instance hint selects a database.
DEFAULT_DB_ALIAS = "default"
13
14
15class DatabaseConfig(TypedDict, total=False):
16 AUTOCOMMIT: bool
17 CONN_MAX_AGE: int | None
18 CONN_HEALTH_CHECKS: bool
19 DISABLE_SERVER_SIDE_CURSORS: bool
20 ENGINE: str
21 HOST: str
22 NAME: str
23 OPTIONS: dict[str, Any] | None
24 PASSWORD: str
25 PORT: str | int
26 TEST: dict[str, Any]
27 TIME_ZONE: str
28 USER: str
29
30
class ConnectionHandler:
    """
    Handler for database connections. Provides lazy connection creation
    and convenience methods for managing multiple database connections.

    Connection objects are stored as attributes of a ``threading.local``
    instance, keyed by their alias.
    """

    def __init__(self):
        self._settings: dict[str, DatabaseConfig] = {}
        self._connections = local()

    @cached_property
    def settings(self) -> dict[str, "DatabaseConfig"]:
        """
        Return the alias -> configuration mapping.

        Computed by ``configure_settings()`` on first access and cached on
        the instance afterwards.
        """
        self._settings = self.configure_settings()
        return self._settings

    def configure_settings(self) -> dict[str, "DatabaseConfig"]:
        """
        Read ``settings.DATABASES`` and fill in defaults for omitted keys.

        The per-database dicts are updated in place with ``setdefault`` and
        the same mapping object is returned.

        Raises:
            ImproperlyConfigured: if no DEFAULT_DB_ALIAS entry is defined.
        """
        databases = plain_settings.DATABASES

        if DEFAULT_DB_ALIAS not in databases:
            raise ImproperlyConfigured(
                f"You must define a '{DEFAULT_DB_ALIAS}' database."
            )

        # Loop-invariant defaults for the nested TEST dict.
        default_test_settings = (
            ("CHARSET", None),
            ("COLLATION", None),
            ("MIRROR", None),
            ("NAME", None),
        )

        # Configure default settings.
        for conn in databases.values():
            conn.setdefault("AUTOCOMMIT", True)
            conn.setdefault("CONN_MAX_AGE", 0)
            conn.setdefault("CONN_HEALTH_CHECKS", False)
            conn.setdefault("OPTIONS", {})
            conn.setdefault("TIME_ZONE", None)
            for setting in ("NAME", "USER", "PASSWORD", "HOST", "PORT"):
                conn.setdefault(setting, "")

            test_settings = conn.setdefault("TEST", {})
            for key, value in default_test_settings:
                test_settings.setdefault(key, value)

        return databases

    def create_connection(self, alias):
        """
        Build a new connection wrapper for ``alias``.

        Imports the ``<ENGINE>.base`` module named by the alias's config and
        returns ``DatabaseWrapper(config, alias)`` from it.
        """
        database_config = self.settings[alias]
        backend = import_module(f"{database_config['ENGINE']}.base")
        return backend.DatabaseWrapper(database_config, alias)

    def __getitem__(self, alias):
        """
        Return the connection for ``alias``, creating and storing it on
        first use.

        Raises:
            ConnectionDoesNotExist: if ``alias`` is not a configured database.
        """
        try:
            return getattr(self._connections, alias)
        except AttributeError:
            # Not created yet; fall through. Raising from outside the handler
            # keeps this internal cache miss out of the caller's traceback.
            pass

        if alias not in self.settings:
            raise ConnectionDoesNotExist(f"The connection '{alias}' doesn't exist.")
        conn = self.create_connection(alias)
        setattr(self._connections, alias, conn)
        return conn

    def __setitem__(self, key, value):
        """Store ``value`` as the connection for alias ``key``."""
        setattr(self._connections, key, value)

    def __delitem__(self, key):
        """Forget the stored connection for alias ``key`` (it is not closed)."""
        delattr(self._connections, key)

    def __iter__(self):
        """Iterate over the configured database aliases."""
        return iter(self.settings)

    def all(self, initialized_only=False):
        """
        Return a list of connections, one per configured alias.

        With ``initialized_only=True`` only connections that have already
        been created are returned; otherwise missing ones are created.
        """
        return [
            self[alias]
            for alias in self
            # If initialized_only is True, return only initialized connections.
            if not initialized_only or hasattr(self._connections, alias)
        ]

    def close_all(self):
        """Call ``close()`` on every connection that has already been created."""
        for conn in self.all(initialized_only=True):
            conn.close()
111
112
class ConnectionRouter:
    """
    Consults a list of router objects to decide which database to use and
    whether relations or migrations are allowed.
    """

    def __init__(self, routers=None):
        """
        If routers is not specified, default to settings.DATABASE_ROUTERS.
        """
        self._routers = routers

    @cached_property
    def routers(self):
        """
        Return the list of router objects, built once and then cached.

        Entries given as strings are imported as dotted paths and called
        with no arguments; any other entry is used as-is.
        """
        if self._routers is None:
            self._routers = plain_settings.DATABASE_ROUTERS
        return [
            import_string(entry)() if isinstance(entry, str) else entry
            for entry in self._routers
        ]

    def _router_func(action):
        """Build a routing method that dispatches to each router's ``action``."""

        def _route_db(self, model, **hints):
            for candidate in self.routers:
                try:
                    route = getattr(candidate, action)
                except AttributeError:
                    # This router does not implement the action; try the next.
                    continue
                selected = route(model, **hints)
                if selected:
                    return selected

            # No router produced an answer: prefer the database recorded on
            # the hinted instance, otherwise use the default alias.
            instance = hints.get("instance")
            if instance is None or not instance._state.db:
                return DEFAULT_DB_ALIAS
            return instance._state.db

        return _route_db

    db_for_read = _router_func("db_for_read")
    db_for_write = _router_func("db_for_write")

    def allow_relation(self, obj1, obj2, **hints):
        """
        Return the first non-None answer from a router's ``allow_relation``.

        If no router answers, the relation is allowed only when both
        objects record the same ``_state.db``.
        """
        for candidate in self.routers:
            try:
                check = candidate.allow_relation
            except AttributeError:
                # This router does not implement the check; try the next.
                continue
            verdict = check(obj1, obj2, **hints)
            if verdict is not None:
                return verdict
        return obj1._state.db == obj2._state.db

    def allow_migrate(self, db, package_label, **hints):
        """
        Return the first non-None answer from a router's ``allow_migrate``.

        If no router answers, migration is allowed (True).
        """
        for candidate in self.routers:
            try:
                check = candidate.allow_migrate
            except AttributeError:
                # This router does not implement the check; try the next.
                pass
            else:
                verdict = check(db, package_label, **hints)
                if verdict is not None:
                    return verdict
        return True

    def allow_migrate_model(self, db, model):
        """Ask ``allow_migrate`` about ``model``, using its ``_meta`` details."""
        meta = model._meta
        return self.allow_migrate(
            db,
            meta.package_label,
            model_name=meta.model_name,
            model=model,
        )

    def get_migratable_models(self, models_registry, package_label, db):
        """Return app models allowed to be migrated on provided db."""
        permitted = []
        for candidate in models_registry.get_models(package_label=package_label):
            if self.allow_migrate_model(db, candidate):
                permitted.append(candidate)
        return permitted