1"""Where `DatabaseConnection` gets its psycopg connection — direct per-use
  2(`DirectSource`) or checkout/return against a shared pool (`PoolSource`).
  3The wrapper calls `source.acquire()` / `source.release()` / `source.config`
  4and is otherwise source-agnostic."""
  5
  6from __future__ import annotations
  7
  8import threading
  9import time
 10from abc import ABC, abstractmethod
 11from typing import TYPE_CHECKING, Any
 12
 13import psycopg
 14from psycopg_pool import ConnectionPool, PoolTimeout
 15
 16from plain.exceptions import ImproperlyConfigured
 17from plain.logs import get_framework_logger
 18from plain.postgres.adapters import get_adapters_template
 19from plain.postgres.database_url import DatabaseConfig, parse_database_url
 20from plain.postgres.dialect import MAX_NAME_LENGTH
 21from plain.postgres.otel import (
 22    record_connection_acquire,
 23    record_connection_release,
 24    record_connection_timeout,
 25)
 26from plain.runtime import settings as plain_settings
 27
# Framework logger; used below for debug-level diagnostics when handing a
# connection back to the pool fails.
logger = get_framework_logger()

# The psycopg Connection type is only referenced in annotations (which are
# lazy here thanks to `from __future__ import annotations`), so it is imported
# for type checkers only and never at runtime.
if TYPE_CHECKING:
    from psycopg import Connection as PsycopgConnection
 32
 33
 34def build_connection_params(config: DatabaseConfig) -> dict[str, Any]:
 35    """Return kwargs suitable for `psycopg.connect()` from a `DatabaseConfig`.
 36
 37    Every psycopg connection Plain opens — pooled or direct — goes through
 38    this function so they all share the same adapters, cursor factory, and
 39    validation rules.
 40    """
 41    options = config.get("OPTIONS", {})
 42    db_name = config["DATABASE"]
 43    if len(db_name) > MAX_NAME_LENGTH:
 44        raise ImproperlyConfigured(
 45            f"The database name {db_name!r} ({len(db_name)} characters) is longer "
 46            f"than PostgreSQL's limit of {MAX_NAME_LENGTH} characters. Supply a "
 47            "shorter database name in POSTGRES_URL."
 48        )
 49    conn_params: dict[str, Any] = {"dbname": db_name, **options}
 50    if config.get("USER"):
 51        conn_params["user"] = config["USER"]
 52    if config.get("PASSWORD"):
 53        conn_params["password"] = config["PASSWORD"]
 54    if config.get("HOST"):
 55        conn_params["host"] = config["HOST"]
 56    if config.get("PORT"):
 57        conn_params["port"] = config["PORT"]
 58    conn_params["context"] = get_adapters_template()
 59    # ClientCursor does client-side parameter binding and issues no
 60    # server-side prepared statements — safe behind transaction-mode
 61    # poolers like pgbouncer.
 62    conn_params["cursor_factory"] = psycopg.ClientCursor
 63    conn_params["prepare_threshold"] = conn_params.pop("prepare_threshold", None)
 64    return conn_params
 65
 66
 67class ConnectionSource(ABC):
 68    @property
 69    @abstractmethod
 70    def config(self) -> DatabaseConfig:
 71        """What server this source connects to. Read by otel, psql helper, maintenance."""
 72
 73    @abstractmethod
 74    def acquire(self) -> PsycopgConnection[Any]: ...
 75
 76    @abstractmethod
 77    def release(self, conn: PsycopgConnection[Any]) -> None: ...
 78
 79
 80class DirectSource(ConnectionSource):
 81    """Opens a fresh psycopg connection per acquire; closes on release."""
 82
 83    def __init__(self, config: DatabaseConfig):
 84        self._config = config
 85        self._params = build_connection_params(config)
 86
 87    @property
 88    def config(self) -> DatabaseConfig:
 89        return self._config
 90
 91    def acquire(self) -> PsycopgConnection[Any]:
 92        return psycopg.connect(**self._params)
 93
 94    def release(self, conn: PsycopgConnection[Any]) -> None:
 95        conn.close()
 96
 97
 98class PoolSource(ConnectionSource):
 99    """Lazily-opened `psycopg_pool.ConnectionPool`. `close()` drops the pool
100    so the next acquire rebuilds against current settings.
101
102    The `name` is used as the `db.client.connection.pool.name` attribute on
103    the `db.client.connection.*` OpenTelemetry metric family.
104    """
105
106    def __init__(self, name: str = "runtime") -> None:
107        self.name = name
108        self._pool: ConnectionPool | None = None
109        self._config: DatabaseConfig | None = None
110        self._lock = threading.Lock()
111
112    @property
113    def config(self) -> DatabaseConfig:
114        if self._config is None:
115            # Opening the pool populates _config as a side effect; until then,
116            # parse lazily so callers that only need config (otel on a no-op
117            # request) don't force the pool open.
118            self._config = _parse_runtime_url()
119        return self._config
120
121    def acquire(self) -> PsycopgConnection[Any]:
122        pool = self._get_pool()
123        start = time.perf_counter()
124        try:
125            conn = pool.getconn()
126        except PoolTimeout:
127            record_connection_timeout(self.name)
128            raise
129        checkout_time = time.perf_counter()
130        record_connection_acquire(self.name, conn, checkout_time - start, checkout_time)
131        return conn
132
133    def release(self, conn: PsycopgConnection[Any]) -> None:
134        record_connection_release(self.name, conn, time.perf_counter())
135        pool = self._pool
136        if pool is None:
137            conn.close()
138            return
139        try:
140            pool.putconn(conn)
141        except Exception:
142            logger.debug("Error returning connection to pool", exc_info=True)
143            conn.close()
144
145    def get_stats(self) -> dict[str, int] | None:
146        """Return pool statistics, or None if the pool is closed."""
147        pool = self._pool
148        if pool is None:
149            return None
150        try:
151            return pool.get_stats()
152        except Exception:
153            return None
154
155    def close(self, timeout: float = 5.0) -> None:
156        with self._lock:
157            self._config = None
158            if self._pool is not None:
159                try:
160                    self._pool.close(timeout=timeout)
161                finally:
162                    self._pool = None
163
164    def _get_pool(self) -> ConnectionPool:
165        if self._pool is None:
166            with self._lock:
167                if self._pool is None:
168                    self._pool = self._open_pool()
169        return self._pool
170
171    def _open_pool(self) -> ConnectionPool:
172        self._config = _parse_runtime_url()
173        params = build_connection_params(self._config)
174        pool = ConnectionPool(
175            kwargs=params,
176            open=False,
177            reset=_reset_pooled_connection,
178            min_size=plain_settings.POSTGRES_POOL_MIN_SIZE,
179            max_size=plain_settings.POSTGRES_POOL_MAX_SIZE,
180            max_lifetime=plain_settings.POSTGRES_POOL_MAX_LIFETIME,
181            timeout=plain_settings.POSTGRES_POOL_TIMEOUT,
182        )
183        pool.open(wait=False)
184        return pool
185
186
def _parse_runtime_url() -> DatabaseConfig:
    """Read `POSTGRES_URL` from settings, reject unusable values, and parse it.

    An empty URL and the literal "none" (case-insensitive) both become
    `ImproperlyConfigured` with a human-readable explanation. This lets
    config-only callers (like `plain postgres shell`) report something
    clearer than whatever lower-level error `parse_database_url` would raise.
    """
    raw_url = str(plain_settings.POSTGRES_URL)

    # The two checks below are mutually exclusive ("" never lowercases to "none").
    if raw_url.lower() == "none":
        raise ImproperlyConfigured(
            "The PostgreSQL database has been disabled (POSTGRES_URL=none). "
            "No database operations are available in this context."
        )
    if raw_url == "":
        raise ImproperlyConfigured(
            "PostgreSQL database is not configured. "
            "Set POSTGRES_URL (or DATABASE_URL) to a postgres://... connection string."
        )

    return parse_database_url(raw_url)
206
207
208def _reset_pooled_connection(conn: PsycopgConnection[Any]) -> None:
209    """Ensure a connection is clean before returning to the pool.
210
211    Rolls back any in-progress transaction and restores autocommit=True so
212    the next checkout starts in a known state. Raising here signals the pool
213    to discard the connection.
214    """
215    if not conn.autocommit:
216        conn.rollback()
217        conn.autocommit = True
218
219
# One shared PoolSource for the whole process. Constructing it does not open
# anything; the underlying pool is created on the first acquire().
runtime_pool_source: PoolSource = PoolSource("runtime")