1from __future__ import annotations
  2
  3import functools
  4import logging
  5import time
  6from collections.abc import Generator, Iterator, Mapping, Sequence
  7from contextlib import contextmanager
  8from hashlib import md5
  9from types import TracebackType
 10from typing import TYPE_CHECKING, Any, Self
 11
 12import psycopg
 13
 14from plain.postgres.db import NotSupportedError
 15from plain.postgres.otel import db_span
 16from plain.utils.dateparse import parse_time
 17
 18if TYPE_CHECKING:
 19    from plain.postgres.connection import DatabaseConnection
 20
 21logger = logging.getLogger("plain.postgres.utils")
 22
 23
 24def make_model_tuple(model: Any) -> tuple[str, str]:
 25    """
 26    Take a model or a string of the form "package_label.ModelName" and return a
 27    corresponding ("package_label", "modelname") tuple. If a tuple is passed in,
 28    assume it's a valid model tuple already and return it unchanged.
 29    """
 30    try:
 31        if isinstance(model, tuple):
 32            model_tuple = model
 33        elif isinstance(model, str):
 34            package_label, model_name = model.split(".")
 35            model_tuple = package_label, model_name.lower()
 36        else:
 37            model_tuple = (
 38                model.model_options.package_label,
 39                model.model_options.model_name,
 40            )
 41        assert len(model_tuple) == 2
 42        return model_tuple
 43    except (ValueError, AssertionError):
 44        raise ValueError(
 45            f"Invalid model reference '{model}'. String model references "
 46            "must be of the form 'package_label.ModelName'."
 47        )
 48
 49
 50def resolve_callables(
 51    mapping: dict[str, Any],
 52) -> Generator[tuple[str, Any]]:
 53    """
 54    Generate key/value pairs for the given mapping where the values are
 55    evaluated if they're callable.
 56    """
 57    for k, v in mapping.items():
 58        yield k, v() if callable(v) else v
 59
 60
 61class CursorWrapper:
 62    def __init__(self, cursor: Any, db: DatabaseConnection) -> None:
 63        self.cursor = cursor
 64        self.db = db
 65
 66    WRAP_ERROR_ATTRS = frozenset(["nextset"])
 67
 68    def __getattr__(self, attr: str) -> Any:
 69        cursor_attr = getattr(self.cursor, attr)
 70        if attr in CursorWrapper.WRAP_ERROR_ATTRS:
 71            return self.db.wrap_database_errors(cursor_attr)
 72        else:
 73            return cursor_attr
 74
 75    def __iter__(self) -> Iterator[tuple[Any, ...]]:
 76        with self.db.wrap_database_errors:
 77            yield from self.cursor
 78
 79    def fetchone(self) -> tuple[Any, ...] | None:
 80        with self.db.wrap_database_errors:
 81            return self.cursor.fetchone()
 82
 83    def fetchmany(self, size: int | None = None) -> list[tuple[Any, ...]]:
 84        with self.db.wrap_database_errors:
 85            if size is None:
 86                return self.cursor.fetchmany()
 87            return self.cursor.fetchmany(size)
 88
 89    def fetchall(self) -> list[tuple[Any, ...]]:
 90        with self.db.wrap_database_errors:
 91            return self.cursor.fetchall()
 92
 93    def __enter__(self) -> Self:
 94        return self
 95
 96    def __exit__(
 97        self,
 98        type: type[BaseException] | None,
 99        value: BaseException | None,
100        traceback: TracebackType | None,
101    ) -> None:
102        # Close instead of passing through to avoid backend-specific behavior
103        # (#17671). Catch errors liberally because errors in cleanup code
104        # aren't useful.
105        try:
106            self.close()
107        except psycopg.Error:
108            pass
109
110    def stream(
111        self, sql: str, params: Sequence[Any] | None = None
112    ) -> Generator[tuple[Any, ...]]:
113        self.db.validate_no_broken_transaction()
114        with db_span(self.db, sql, params=params):
115            with self.db.wrap_database_errors:
116                try:
117                    if params is None:
118                        yield from self.cursor.stream(sql)
119                    else:
120                        yield from self.cursor.stream(sql, params)
121                finally:
122                    try:
123                        self.close()
124                    except psycopg.Error:
125                        pass
126
127    # The following methods cannot be implemented in __getattr__, because the
128    # code must run when the method is invoked, not just when it is accessed.
129
130    def callproc(
131        self,
132        procname: str,
133        params: Sequence[Any] | None = None,
134        kparams: Mapping[str, Any] | None = None,
135    ) -> Any:
136        # Keyword parameters for callproc aren't supported in PEP 249.
137        # PostgreSQL's psycopg doesn't support them either.
138        if kparams is not None:
139            raise NotSupportedError(
140                "Keyword parameters for callproc are not supported."
141            )
142        self.db.validate_no_broken_transaction()
143        with self.db.wrap_database_errors:
144            if params is None:
145                return self.cursor.callproc(procname)
146            return self.cursor.callproc(procname, params)
147
148    def execute(
149        self, sql: str, params: Sequence[Any] | Mapping[str, Any] | None = None
150    ) -> Self:
151        return self._execute_with_wrappers(
152            sql, params, many=False, executor=self._execute
153        )
154
155    def executemany(self, sql: str, param_list: Sequence[Sequence[Any]]) -> Self:
156        return self._execute_with_wrappers(
157            sql, param_list, many=True, executor=self._executemany
158        )
159
160    def _execute_with_wrappers(
161        self, sql: str, params: Any, many: bool, executor: Any
162    ) -> Self:
163        context: dict[str, Any] = {"connection": self.db, "cursor": self}
164        for wrapper in reversed(self.db.execute_wrappers):
165            executor = functools.partial(wrapper, executor)
166        executor(sql, params, many, context)
167        return self
168
169    def _execute(self, sql: str, params: Any, *ignored_wrapper_args: Any) -> None:
170        # Wrap in an OpenTelemetry span with standard attributes.
171        with db_span(self.db, sql, params=params):
172            self.db.validate_no_broken_transaction()
173            with self.db.wrap_database_errors:
174                if params is None:
175                    self.cursor.execute(sql)
176                else:
177                    self.cursor.execute(sql, params)
178
179    def _executemany(
180        self, sql: str, param_list: Any, *ignored_wrapper_args: Any
181    ) -> None:
182        with db_span(self.db, sql, many=True, params=param_list):
183            self.db.validate_no_broken_transaction()
184            with self.db.wrap_database_errors:
185                self.cursor.executemany(sql, param_list)
186
187
class CursorDebugWrapper(CursorWrapper):
    """
    CursorWrapper that times every statement it runs and records the result
    in ``db.queries_log`` and on the module logger.
    """

    # XXX callproc isn't instrumented at this time.

    def stream(
        self, sql: str, params: Sequence[Any] | None = None
    ) -> Generator[tuple[Any, ...]]:
        """Stream rows as the base class does, timing the whole iteration."""
        timer = self.debug_sql(sql, params, use_last_executed_query=True)
        with timer:
            yield from super().stream(sql, params)

    def execute(
        self, sql: str, params: Sequence[Any] | Mapping[str, Any] | None = None
    ) -> Self:
        """Execute one statement as the base class does, timing the call."""
        timer = self.debug_sql(sql, params, use_last_executed_query=True)
        with timer:
            super().execute(sql, params)
        return self

    def executemany(self, sql: str, param_list: Sequence[Sequence[Any]]) -> Self:
        """Execute a batch as the base class does, timing the call."""
        timer = self.debug_sql(sql, param_list, many=True)
        with timer:
            super().executemany(sql, param_list)
        return self

    @contextmanager
    def debug_sql(
        self,
        sql: str | None = None,
        params: Any = None,
        use_last_executed_query: bool = False,
        many: bool = False,
    ) -> Generator[None]:
        """
        Context manager that measures the wrapped block with a monotonic
        clock and logs it, whether or not the block raised.

        With ``use_last_executed_query`` the logged SQL is whatever
        ``db.last_executed_query()`` returns; with ``many`` the
        ``queries_log`` entry is prefixed with the number of parameter sets.
        """
        began = time.monotonic()
        try:
            yield
        finally:
            # Take the reading first so the bookkeeping below is not timed.
            duration = time.monotonic() - began
            if use_last_executed_query:
                sql = self.db.last_executed_query(self.cursor, sql, params)  # type: ignore[arg-type]
            if many:
                try:
                    times = len(params)
                except TypeError:
                    # params could be an iterator.
                    times = "?"
                logged_sql = f"{times} times: {sql}"
            else:
                logged_sql = sql
            entry = {
                "sql": logged_sql,
                "time": f"{duration:.3f}",
            }
            self.db.queries_log.append(entry)
            log_extra = {
                "duration": duration,
                "sql": sql,
                "params": params,
            }
            logger.debug(
                "(%.3f) %s; args=%s", duration, sql, params, extra=log_extra
            )
247
248
@contextmanager
def debug_transaction(connection: DatabaseConnection, sql: str) -> Generator[None]:
    """
    Context manager that times the wrapped block and, when
    ``connection.queries_logged`` is true on exit, records ``sql`` with the
    elapsed time in ``connection.queries_log`` and on the module logger.

    The record is written whether or not the block raised.
    """
    began = time.monotonic()
    try:
        yield
    finally:
        if connection.queries_logged:
            duration = time.monotonic() - began
            entry = {
                "sql": f"{sql}",
                "time": f"{duration:.3f}",
            }
            connection.queries_log.append(entry)
            log_extra = {
                "duration": duration,
                "sql": sql,
            }
            logger.debug(
                "(%.3f) %s; args=%s", duration, sql, None, extra=log_extra
            )
274
275
def split_tzname_delta(tzname: str) -> tuple[str, str | None, str | None]:
    """
    Break a time zone name into (name, sign, offset).

    "+" is tried before "-", each at its last occurrence. The split is
    accepted only when the text after the sign is non-empty and
    ``parse_time`` returns a truthy value for it; otherwise the result is
    (tzname, None, None).
    """
    for sign in ("+", "-"):
        name, found, offset = tzname.rpartition(sign)
        # `found` is empty when the sign does not occur in tzname at all.
        if found and offset and parse_time(offset):
            return name, sign, offset
    return tzname, None, None
286
287
288###############################################
289# Converters from Python to database (string) #
290###############################################
291
292
def split_identifier(identifier: str) -> tuple[str, str]:
    """
    Split an SQL identifier into a two element tuple of (namespace, name).

    The identifier could be a table, column, or sequence name, and it may be
    prefixed by a namespace. Double quotes at either end of both parts are
    stripped. When the identifier does not contain exactly one '"."'
    separator, the namespace is the empty string and the whole identifier is
    treated as the name.
    """
    pieces = identifier.split('"."')
    if len(pieces) == 2:
        namespace, name = pieces
    else:
        namespace, name = "", identifier
    return namespace.strip('"'), name.strip('"')
305
306
def truncate_name(identifier: str, length: int | None = None, hash_len: int = 4) -> str:
    """
    Return a repeatable shortened form of an SQL identifier.

    The identifier is returned untouched when ``length`` is None or the name
    part already fits. Otherwise the name is cut to ``length - hash_len``
    characters and a ``hash_len``-character digest of the full name is
    appended.

    If a quote stripped name contains a namespace, e.g. USERNAME"."TABLE,
    only the table portion is truncated and the namespace is kept as prefix.
    """
    namespace, name = split_identifier(identifier)

    if length is not None and len(name) > length:
        suffix = names_digest(name, length=hash_len)
        prefix = f'{namespace}"."' if namespace else ""
        return prefix + name[: length - hash_len] + suffix

    return identifier
326
327
def names_digest(*args: str, length: int) -> str:
    """
    Return the first ``length`` hexadecimal characters of the MD5 digest of
    the given strings, fed to the hash in order. The result can be used to
    shorten identifying names.
    """
    # Hashing the concatenated encodings is the same as updating the hash
    # with each encoded argument in turn.
    payload = b"".join(arg.encode() for arg in args)
    return md5(payload, usedforsecurity=False).hexdigest()[:length]
337
338
def strip_quotes(table_name: str) -> str:
    """
    Remove one pair of surrounding double quotes from a table name so it is
    safe for use in index names, sequence names, etc. Names that are not
    quoted at both ends are returned unchanged.
    """
    if table_name[:1] == '"' and table_name[-1:] == '"':
        return table_name[1:-1]
    return table_name