1from __future__ import annotations
 2
 3from importlib import import_module
 4from threading import local
 5from typing import TYPE_CHECKING, Any, TypedDict
 6
 7from plain.runtime import settings as plain_settings
 8
 9if TYPE_CHECKING:
10    from plain.models.backends.base.base import BaseDatabaseWrapper
11
12
13class DatabaseConfig(TypedDict, total=False):
14    AUTOCOMMIT: bool
15    CONN_MAX_AGE: int | None
16    CONN_HEALTH_CHECKS: bool
17    DISABLE_SERVER_SIDE_CURSORS: bool
18    ENGINE: str
19    HOST: str
20    NAME: str | None
21    OPTIONS: dict[str, Any]
22    PASSWORD: str
23    PORT: str | int
24    TEST: dict[str, Any]
25    TIME_ZONE: str | None
26    USER: str
27
28
class DatabaseConnection:
    """Lazy, per-thread proxy for the single configured database connection.

    The backend wrapper is created the first time an attribute is read from
    or written to this object, and is stored in a ``threading.local`` so each
    thread that touches the proxy gets its own wrapper.

    Attribute reads that normal lookup cannot satisfy are forwarded to the
    wrapper.  Attribute writes are forwarded as well, except for names that
    start with an underscore, which are stored on the proxy itself.
    """

    __slots__ = ("_settings", "_local")

    def __init__(self) -> None:
        self._settings: DatabaseConfig = {}
        self._local = local()

    def configure_settings(self) -> DatabaseConfig:
        """Fill in missing defaults on ``settings.DATABASE`` and return it.

        The settings dict is updated in place via ``setdefault``; values that
        are already present are left untouched.
        """
        database = plain_settings.DATABASE

        database.setdefault("AUTOCOMMIT", True)
        database.setdefault("CONN_MAX_AGE", 0)
        database.setdefault("CONN_HEALTH_CHECKS", False)
        database.setdefault("OPTIONS", {})
        database.setdefault("TIME_ZONE", None)
        for key in ("NAME", "USER", "PASSWORD", "HOST", "PORT"):
            database.setdefault(key, "")

        # The nested TEST dict gets its own defaults, all of them None.
        test_settings = database.setdefault("TEST", {})
        for key in ("CHARSET", "COLLATION", "MIRROR", "NAME"):
            test_settings.setdefault(key, None)

        return database

    def create_connection(self) -> BaseDatabaseWrapper:
        """Build a new backend wrapper from the configured settings.

        Imports ``<ENGINE>.base`` and instantiates its wrapper class with the
        settings dict.

        Raises:
            KeyError: if the settings have no ``ENGINE`` entry.
            ImportError: if ``<ENGINE>.base`` cannot be imported.
            AttributeError: if the module lacks the expected wrapper class.
        """
        database_config = self.configure_settings()
        engine = database_config["ENGINE"]
        backend = import_module(f"{engine}.base")

        # Built-in engines (keyed by module path) use a vendor-specific wrapper
        # class name; any other engine is expected to expose "DatabaseWrapper".
        wrapper_names = {
            "plain.models.backends.sqlite3": "SQLiteDatabaseWrapper",
            "plain.models.backends.mysql": "MySQLDatabaseWrapper",
            "plain.models.backends.postgresql": "PostgreSQLDatabaseWrapper",
        }
        wrapper_class = getattr(backend, wrapper_names.get(engine, "DatabaseWrapper"))
        return wrapper_class(database_config)

    def has_connection(self) -> bool:
        """Return True if the current thread already has a wrapper."""
        return hasattr(self._local, "conn")

    def _get_or_create_wrapper(self) -> BaseDatabaseWrapper:
        """Return the current thread's wrapper, creating it on first use."""
        if not self.has_connection():
            self._local.conn = self.create_connection()
        return self._local.conn

    def __getattr__(self, attr: str) -> Any:
        # __getattr__ only runs after normal lookup has failed.  If one of the
        # proxy's own slots is unset (e.g. an instance created via __new__
        # without __init__), forwarding would call has_connection(), which
        # reads self._local and re-enters this method without end.  Fail with
        # a regular AttributeError instead so hasattr()/getattr() behave.
        if attr in DatabaseConnection.__slots__:
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {attr!r}"
            )

        return getattr(self._get_or_create_wrapper(), attr)

    def __setattr__(self, name: str, value: Any) -> None:
        if name.startswith("_"):
            # Underscore-prefixed names belong to the proxy itself.
            super().__setattr__(name, value)
        else:
            setattr(self._get_or_create_wrapper(), name, value)