1from __future__ import annotations
 2
 3from contextvars import ContextVar
 4from typing import TYPE_CHECKING, Any, TypedDict
 5
 6from plain.exceptions import ImproperlyConfigured
 7from plain.runtime import settings as plain_settings
 8
 9if TYPE_CHECKING:
10    from plain.postgres.connection import DatabaseConnection
11
12
class DatabaseConfig(TypedDict, total=False):
    """Settings dictionary used to construct a ``DatabaseConnection``.

    Declared with ``total=False``, so static type checkers treat every key as
    optional. In this module the dict is produced by ``_configure_settings()``,
    which fills every key below from the matching ``POSTGRES_*`` setting.
    """

    # From POSTGRES_CONN_MAX_AGE. NOTE(review): presumably a connection
    # lifetime in seconds with None meaning "no limit" — confirm against
    # DatabaseConnection.
    CONN_MAX_AGE: int | None
    # From POSTGRES_CONN_HEALTH_CHECKS.
    CONN_HEALTH_CHECKS: bool
    # From POSTGRES_HOST.
    HOST: str
    # From POSTGRES_DATABASE. _configure_settings() raises before building the
    # dict when that setting is "" or otherwise falsy, so it never stores None.
    DATABASE: str | None
    # From POSTGRES_OPTIONS; _configure_settings() passes the settings' own
    # dict object through without copying it.
    OPTIONS: dict[str, Any]
    # From POSTGRES_PASSWORD.
    PASSWORD: str
    # From POSTGRES_PORT.
    PORT: int | None
    # _configure_settings() always sets this to {"DATABASE": None}.
    # NOTE(review): its consumer is not visible in this module.
    TEST: dict[str, Any]
    # From POSTGRES_TIME_ZONE.
    TIME_ZONE: str | None
    # From POSTGRES_USER.
    USER: str
24
25
# Module-level ContextVar holding the DatabaseConnection for the current
# execution context. Its default is None; get_connection() creates and stores
# a connection the first time it is asked for one in a context.
#
# - asyncio: each Task runs in a copy of the context captured when the Task
#   is created (PEP 567, Python 3.7+). A connection stored from inside a Task
#   is therefore not visible to its parent or sibling Tasks. The copy is
#   shallow: if the creating context already held a connection, the new Task
#   starts out referencing that same DatabaseConnection object.
# - Threads: each thread has its own context. A pool worker that runs work
#   items directly in that context (as concurrent.futures.ThreadPoolExecutor
#   does) keeps its connection across work items. The intent is that
#   connections persist across requests, honoring CONN_MAX_AGE.
# NOTE(review): code run through contextvars.copy_context().run (for example
#   asyncio.to_thread) sees the caller's connection object from another
#   thread. Any connection it creates is also discarded with the copied
#   context. Confirm that such paths never use the database concurrently.
_db_conn: ContextVar[DatabaseConnection | None] = ContextVar("_db_conn", default=None)
31
32
def _configure_settings() -> DatabaseConfig:
    """Assemble the connection settings dict from the ``POSTGRES_*`` settings.

    Returns:
        A ``DatabaseConfig`` with every key populated. ``TEST`` is always
        ``{"DATABASE": None}``.

    Raises:
        ImproperlyConfigured: if ``POSTGRES_DATABASE`` is the empty string
            (database explicitly disabled, e.g. ``DATABASE_URL=none``) or is
            any other falsy value such as ``None``.
    """
    database_name = plain_settings.POSTGRES_DATABASE

    # The empty string is a deliberate "disabled" marker. Test for it before
    # the generic falsy check, which would otherwise swallow it.
    if database_name == "":
        raise ImproperlyConfigured(
            "The PostgreSQL database has been disabled (DATABASE_URL=none). "
            "No database operations are available in this context."
        )
    if not database_name:
        raise ImproperlyConfigured(
            "PostgreSQL database is not configured. "
            "Set DATABASE_URL or the individual POSTGRES_* settings."
        )

    # Every value except TEST is taken straight from settings. OPTIONS is the
    # settings object's own dict, not a copy.
    config: DatabaseConfig = {
        "DATABASE": database_name,
        "USER": plain_settings.POSTGRES_USER,
        "PASSWORD": plain_settings.POSTGRES_PASSWORD,
        "HOST": plain_settings.POSTGRES_HOST,
        "PORT": plain_settings.POSTGRES_PORT,
        "CONN_MAX_AGE": plain_settings.POSTGRES_CONN_MAX_AGE,
        "CONN_HEALTH_CHECKS": plain_settings.POSTGRES_CONN_HEALTH_CHECKS,
        "OPTIONS": plain_settings.POSTGRES_OPTIONS,
        "TIME_ZONE": plain_settings.POSTGRES_TIME_ZONE,
        "TEST": {"DATABASE": None},
    }
    return config
57
58
def _create_connection() -> DatabaseConnection:
    """Construct a fresh ``DatabaseConnection`` from the current settings.

    Raises:
        ImproperlyConfigured: propagated from ``_configure_settings()`` when
            the database is disabled or not configured.
    """
    # Deferred import: at module level this name is only imported under
    # TYPE_CHECKING (presumably to avoid an import cycle — confirm).
    from plain.postgres.connection import DatabaseConnection

    return DatabaseConnection(_configure_settings())
64
65
def get_connection() -> DatabaseConnection:
    """Return the current context's database connection, creating it lazily.

    On the first call in a context, a new connection is built with
    ``_create_connection()`` and stored in ``_db_conn``. Later calls in the
    same context return that stored object.
    """
    existing = _db_conn.get()
    if existing is not None:
        return existing

    created = _create_connection()
    _db_conn.set(created)
    return created
73
74
def has_connection() -> bool:
    """Report whether the current context already holds a database connection.

    This only inspects ``_db_conn``; it never creates a connection.
    """
    current = _db_conn.get()
    return current is not None