1from __future__ import annotations
2
3from contextvars import ContextVar
4from typing import TYPE_CHECKING, Any, TypedDict
5
6from plain.exceptions import ImproperlyConfigured
7from plain.runtime import settings as plain_settings
8
9if TYPE_CHECKING:
10 from plain.postgres.connection import DatabaseConnection
11
12
class DatabaseConfig(TypedDict, total=False):
    """Connection settings handed to ``DatabaseConnection``.

    Declared with ``total=False``, so every key is optional as far as type
    checkers are concerned. In this module the dict is built by
    ``_configure_settings()`` from the ``POSTGRES_*`` settings.
    """

    # NOTE(review): presumably the maximum lifetime of a persistent
    # connection (seconds?), with None meaning "no limit" — confirm in
    # DatabaseConnection.
    CONN_MAX_AGE: int | None
    # NOTE(review): presumably enables a liveness check before reusing a
    # persistent connection — confirm in DatabaseConnection.
    CONN_HEALTH_CHECKS: bool
    HOST: str
    # Database name; _configure_settings() refuses to build a config when
    # this is empty or otherwise falsy.
    DATABASE: str | None
    # Extra options, copied as-is from the POSTGRES_OPTIONS setting.
    OPTIONS: dict[str, Any]
    PASSWORD: str
    PORT: int | None
    # Test-database overrides; _configure_settings() sets {"DATABASE": None}.
    TEST: dict[str, Any]
    TIME_ZONE: str | None
    USER: str
24
25
# Module-level ContextVar holding the per-task / per-thread DatabaseConnection.
#
# asyncio: each Task runs in a *copy* of the context that was current when the
# task was created (PEP 567). A connection created inside a task is therefore
# invisible to other tasks. A connection that already existed in the creating
# context is inherited, though, so the new task shares that same object.
#
# Threads: each thread has its own context. Pool threads that run work items
# directly in that context (e.g. concurrent.futures.ThreadPoolExecutor) keep
# their connection across work items, so connections persist across requests
# (honoring CONN_MAX_AGE). Helpers that run each call inside a freshly copied
# context (e.g. asyncio.to_thread) do not retain a connection created during
# that call.
_db_conn: ContextVar[DatabaseConnection | None] = ContextVar("_db_conn", default=None)
31
32
def _configure_settings() -> DatabaseConfig:
    """Assemble the connection config from the ``POSTGRES_*`` settings.

    Returns:
        A ``DatabaseConfig`` dict built from the current settings, with
        ``TEST`` fixed to ``{"DATABASE": None}``.

    Raises:
        ImproperlyConfigured: if ``POSTGRES_DATABASE`` is the empty string
            (the database was explicitly disabled), or if it is ``None`` or
            otherwise falsy (never configured).
    """
    database_name = plain_settings.POSTGRES_DATABASE

    # The empty string is the explicit "disabled" marker. It must be checked
    # before the generic falsy test below, which would otherwise swallow it
    # and report the wrong reason.
    if database_name == "":
        raise ImproperlyConfigured(
            "The PostgreSQL database has been disabled (DATABASE_URL=none). "
            "No database operations are available in this context."
        )

    # Anything else falsy (None, or a setting that never got resolved).
    if not database_name:
        raise ImproperlyConfigured(
            "PostgreSQL database is not configured. "
            "Set DATABASE_URL or the individual POSTGRES_* settings."
        )

    config: DatabaseConfig = {
        "DATABASE": database_name,
        "USER": plain_settings.POSTGRES_USER,
        "PASSWORD": plain_settings.POSTGRES_PASSWORD,
        "HOST": plain_settings.POSTGRES_HOST,
        "PORT": plain_settings.POSTGRES_PORT,
        "CONN_MAX_AGE": plain_settings.POSTGRES_CONN_MAX_AGE,
        "CONN_HEALTH_CHECKS": plain_settings.POSTGRES_CONN_HEALTH_CHECKS,
        "OPTIONS": plain_settings.POSTGRES_OPTIONS,
        "TIME_ZONE": plain_settings.POSTGRES_TIME_ZONE,
        "TEST": {"DATABASE": None},
    }
    return config
57
58
def _create_connection() -> DatabaseConnection:
    """Build a new ``DatabaseConnection`` from the current settings.

    Raises:
        ImproperlyConfigured: propagated from ``_configure_settings()`` when
            the database is disabled or not configured.
    """
    # Imported at call time: the module-level import of DatabaseConnection
    # only exists under TYPE_CHECKING (presumably to avoid an import cycle).
    from plain.postgres.connection import DatabaseConnection

    return DatabaseConnection(_configure_settings())
64
65
def get_connection() -> DatabaseConnection:
    """Return the current context's database connection, creating it lazily.

    The connection is cached in the ``_db_conn`` context variable, so later
    calls from the same context reuse the same object. If creation raises
    (e.g. ``ImproperlyConfigured``), the error propagates and nothing is
    cached.
    """
    existing = _db_conn.get()
    if existing is not None:
        return existing

    # First use in this context: create, then remember it for later calls.
    created = _create_connection()
    _db_conn.set(created)
    return created
73
74
def has_connection() -> bool:
    """Report whether the current context already holds a database connection.

    Only inspects ``_db_conn``; unlike ``get_connection()``, it never creates
    a connection.
    """
    current = _db_conn.get()
    return current is not None