1from __future__ import annotations
2
3from importlib import import_module
4from threading import local
5from typing import TYPE_CHECKING, Any, TypedDict
6
7from plain.runtime import settings as plain_settings
8
9if TYPE_CHECKING:
10 from plain.models.backends.base.base import BaseDatabaseWrapper
11
12
class DatabaseConfig(TypedDict, total=False):
    """Shape of the ``settings.DATABASE`` dict.

    ``total=False`` makes every key optional for type checkers.
    ``DatabaseConnection.configure_settings`` fills in a default for most
    keys (noted per field below); keys without a noted default are left
    exactly as the user provided them (or absent).
    """

    # Defaults to True in configure_settings().
    AUTOCOMMIT: bool
    # Defaults to 0 in configure_settings().
    CONN_MAX_AGE: int | None
    # Defaults to False in configure_settings().
    CONN_HEALTH_CHECKS: bool
    # NOTE(review): no default is applied in this module; presumably read by
    # the backend wrapper — confirm semantics there.
    DISABLE_SERVER_SIDE_CURSORS: bool
    # Dotted path of the backend package. create_connection() imports
    # "<ENGINE>.base" and indexes this key directly, so it must be set.
    ENGINE: str
    # Defaults to "" in configure_settings().
    HOST: str
    # Defaults to "" in configure_settings() (the type also allows None).
    NAME: str | None
    # Defaults to {} in configure_settings().
    OPTIONS: dict[str, Any]
    # Defaults to "" in configure_settings().
    PASSWORD: str
    # Defaults to "" in configure_settings().
    PORT: str | int
    # Defaults to {} in configure_settings(); its CHARSET, COLLATION, MIRROR
    # and NAME sub-keys each default to None.
    TEST: dict[str, Any]
    # Defaults to None in configure_settings().
    TIME_ZONE: str | None
    # Defaults to "" in configure_settings().
    USER: str
27
28
class DatabaseConnection:
    """Lazy access to the single configured database connection.

    Attribute reads and writes on this object are forwarded to a backend
    wrapper instance that is created on first use and stored on a
    ``threading.local``, so each thread gets its own wrapper.
    """

    __slots__ = ("_settings", "_local")

    # ENGINE module path -> wrapper class name exported by its ``.base``
    # module. Engines not listed here fall back to ``DatabaseWrapper``.
    _WRAPPER_CLASS_NAMES = {
        "plain.models.backends.sqlite3": "SQLiteDatabaseWrapper",
        "plain.models.backends.mysql": "MySQLDatabaseWrapper",
        "plain.models.backends.postgresql": "PostgreSQLDatabaseWrapper",
    }

    def __init__(self) -> None:
        # NOTE(review): `_settings` is not read by any method of this class;
        # configure_settings() returns the settings dict directly instead.
        self._settings: DatabaseConfig = {}
        # Per-thread storage; the wrapper lives at `self._local.conn`.
        self._local = local()

    def configure_settings(self) -> DatabaseConfig:
        """Fill in defaults on ``settings.DATABASE`` and return it.

        The settings dict is mutated in place with ``setdefault``, so values
        the user already provided are never overwritten.
        """
        database = plain_settings.DATABASE

        database.setdefault("AUTOCOMMIT", True)
        database.setdefault("CONN_MAX_AGE", 0)
        database.setdefault("CONN_HEALTH_CHECKS", False)
        database.setdefault("OPTIONS", {})
        database.setdefault("TIME_ZONE", None)
        for setting in ("NAME", "USER", "PASSWORD", "HOST", "PORT"):
            database.setdefault(setting, "")

        test_settings = database.setdefault("TEST", {})
        for key in ("CHARSET", "COLLATION", "MIRROR", "NAME"):
            test_settings.setdefault(key, None)

        return database

    def create_connection(self) -> BaseDatabaseWrapper:
        """Build a new backend wrapper for the configured ENGINE.

        Imports ``<ENGINE>.base`` and instantiates the wrapper class listed in
        ``_WRAPPER_CLASS_NAMES`` (``DatabaseWrapper`` for unlisted engines),
        passing it the configured settings dict.

        Raises:
            KeyError: if ``ENGINE`` is not set.
            ImportError: if ``<ENGINE>.base`` cannot be imported.
            AttributeError: if that module lacks the expected wrapper class.
        """
        database_config = self.configure_settings()
        engine = database_config["ENGINE"]
        backend = import_module(f"{engine}.base")
        wrapper_class_name = self._WRAPPER_CLASS_NAMES.get(
            engine, "DatabaseWrapper"
        )
        wrapper_class = getattr(backend, wrapper_class_name)
        return wrapper_class(database_config)

    def has_connection(self) -> bool:
        """Return whether the current thread already has a wrapper."""
        return hasattr(self._local, "conn")

    def _get_or_create_connection(self) -> BaseDatabaseWrapper:
        """Return this thread's wrapper, creating it on first access."""
        if not self.has_connection():
            self._local.conn = self.create_connection()
        return self._local.conn

    def __getattr__(self, attr: str) -> Any:
        """Forward attribute reads to this thread's wrapper.

        Only called when normal lookup fails, so the proxy's own methods and
        slots are never forwarded.
        """
        # Reaching here for one of our own slot names means the slot was never
        # assigned (e.g. an instance created via __new__ by copy/pickle before
        # its state is restored). Forwarding would call has_connection(),
        # which reads self._local and re-enters this method forever.
        if attr in DatabaseConnection.__slots__:
            raise AttributeError(
                f"'{type(self).__name__}' object has no attribute '{attr}'"
            )
        return getattr(self._get_or_create_connection(), attr)

    def __setattr__(self, name: str, value: Any) -> None:
        """Store private names on the proxy; forward all others to the wrapper."""
        if name.startswith("_"):
            super().__setattr__(name, value)
        else:
            setattr(self._get_or_create_connection(), name, value)