1from __future__ import annotations
  2
  3import functools
  4import logging
  5import time
  6from collections.abc import Generator, Iterator, Mapping, Sequence
  7from contextlib import contextmanager
  8from hashlib import md5
  9from types import TracebackType
 10from typing import TYPE_CHECKING, Any, Self
 11
 12import psycopg
 13
 14from plain.models.db import NotSupportedError
 15from plain.models.otel import db_span
 16from plain.utils.dateparse import parse_time
 17
 18if TYPE_CHECKING:
 19    from plain.models.postgres.wrapper import DatabaseWrapper
 20
 21logger = logging.getLogger("plain.models.postgres")
 22
 23
 24class CursorWrapper:
 25    def __init__(self, cursor: Any, db: DatabaseWrapper) -> None:
 26        self.cursor = cursor
 27        self.db = db
 28
 29    WRAP_ERROR_ATTRS = frozenset(["nextset"])
 30
 31    def __getattr__(self, attr: str) -> Any:
 32        cursor_attr = getattr(self.cursor, attr)
 33        if attr in CursorWrapper.WRAP_ERROR_ATTRS:
 34            return self.db.wrap_database_errors(cursor_attr)
 35        else:
 36            return cursor_attr
 37
 38    def __iter__(self) -> Iterator[tuple[Any, ...]]:
 39        with self.db.wrap_database_errors:
 40            yield from self.cursor
 41
 42    def fetchone(self) -> tuple[Any, ...] | None:
 43        with self.db.wrap_database_errors:
 44            return self.cursor.fetchone()
 45
 46    def fetchmany(self, size: int | None = None) -> list[tuple[Any, ...]]:
 47        with self.db.wrap_database_errors:
 48            if size is None:
 49                return self.cursor.fetchmany()
 50            return self.cursor.fetchmany(size)
 51
 52    def fetchall(self) -> list[tuple[Any, ...]]:
 53        with self.db.wrap_database_errors:
 54            return self.cursor.fetchall()
 55
 56    def __enter__(self) -> Self:
 57        return self
 58
 59    def __exit__(
 60        self,
 61        type: type[BaseException] | None,
 62        value: BaseException | None,
 63        traceback: TracebackType | None,
 64    ) -> None:
 65        # Close instead of passing through to avoid backend-specific behavior
 66        # (#17671). Catch errors liberally because errors in cleanup code
 67        # aren't useful.
 68        try:
 69            self.close()
 70        except psycopg.Error:
 71            pass
 72
 73    # The following methods cannot be implemented in __getattr__, because the
 74    # code must run when the method is invoked, not just when it is accessed.
 75
 76    def callproc(
 77        self,
 78        procname: str,
 79        params: Sequence[Any] | None = None,
 80        kparams: Mapping[str, Any] | None = None,
 81    ) -> Any:
 82        # Keyword parameters for callproc aren't supported in PEP 249.
 83        # PostgreSQL's psycopg doesn't support them either.
 84        if kparams is not None:
 85            raise NotSupportedError(
 86                "Keyword parameters for callproc are not supported."
 87            )
 88        self.db.validate_no_broken_transaction()
 89        with self.db.wrap_database_errors:
 90            if params is None:
 91                return self.cursor.callproc(procname)
 92            return self.cursor.callproc(procname, params)
 93
 94    def execute(
 95        self, sql: str, params: Sequence[Any] | Mapping[str, Any] | None = None
 96    ) -> Self:
 97        return self._execute_with_wrappers(
 98            sql, params, many=False, executor=self._execute
 99        )
100
101    def executemany(self, sql: str, param_list: Sequence[Sequence[Any]]) -> Self:
102        return self._execute_with_wrappers(
103            sql, param_list, many=True, executor=self._executemany
104        )
105
106    def _execute_with_wrappers(
107        self, sql: str, params: Any, many: bool, executor: Any
108    ) -> Self:
109        context: dict[str, Any] = {"connection": self.db, "cursor": self}
110        for wrapper in reversed(self.db.execute_wrappers):
111            executor = functools.partial(wrapper, executor)
112        executor(sql, params, many, context)
113        return self
114
115    def _execute(self, sql: str, params: Any, *ignored_wrapper_args: Any) -> None:
116        # Wrap in an OpenTelemetry span with standard attributes.
117        with db_span(self.db, sql, params=params):
118            self.db.validate_no_broken_transaction()
119            with self.db.wrap_database_errors:
120                if params is None:
121                    self.cursor.execute(sql)
122                else:
123                    self.cursor.execute(sql, params)
124
125    def _executemany(
126        self, sql: str, param_list: Any, *ignored_wrapper_args: Any
127    ) -> None:
128        with db_span(self.db, sql, many=True, params=param_list):
129            self.db.validate_no_broken_transaction()
130            with self.db.wrap_database_errors:
131                self.cursor.executemany(sql, param_list)
132
133
class CursorDebugWrapper(CursorWrapper):
    """
    CursorWrapper variant that times each execute()/executemany() call.

    Each call appends an entry to ``db.queries_log`` and emits a debug log
    record. callproc() is not instrumented at this time.
    """

    def execute(
        self, sql: str, params: Sequence[Any] | Mapping[str, Any] | None = None
    ) -> Self:
        # Log the driver's rendering of the statement rather than the raw SQL.
        with self.debug_sql(sql, params, use_last_executed_query=True):
            super().execute(sql, params)
        return self

    def executemany(self, sql: str, param_list: Sequence[Sequence[Any]]) -> Self:
        with self.debug_sql(sql, param_list, many=True):
            super().executemany(sql, param_list)
        return self

    @contextmanager
    def debug_sql(
        self,
        sql: str | None = None,
        params: Any = None,
        use_last_executed_query: bool = False,
        many: bool = False,
    ) -> Generator[None, None, None]:
        """
        Time the enclosed block, then record it in ``db.queries_log`` and the
        module logger.

        The record is written in ``finally``, so statements that raise are
        logged too. With ``many=True`` the logged SQL is prefixed with the
        number of parameter sets, or "?" when ``params`` has no length.
        """
        started = time.monotonic()
        try:
            yield
        finally:
            elapsed = time.monotonic() - started
            if use_last_executed_query:
                sql = self.db.last_executed_query(self.cursor, sql, params)  # type: ignore[arg-type]
            if many:
                try:
                    times = len(params)
                except TypeError:
                    # params could be an iterator.
                    times = "?"
                logged_sql = f"{times} times: {sql}"
            else:
                logged_sql = sql
            self.db.queries_log.append(
                {
                    "sql": logged_sql,
                    "time": f"{elapsed:.3f}",
                }
            )
            logger.debug(
                "(%.3f) %s; args=%s",
                elapsed,
                sql,
                params,
                extra={
                    "duration": elapsed,
                    "sql": sql,
                    "params": params,
                },
            )
187
188
@contextmanager
def debug_transaction(
    connection: DatabaseWrapper, sql: str
) -> Generator[None, None, None]:
    """
    Time a transaction-control statement and log it when query logging is on.

    Nothing is recorded unless ``connection.queries_logged`` is truthy. The
    record is written in ``finally``, so it is also emitted when the block
    raises. The exception itself still propagates.
    """
    began = time.monotonic()
    try:
        yield
    finally:
        if connection.queries_logged:
            elapsed = time.monotonic() - began
            connection.queries_log.append(
                {
                    "sql": f"{sql}",
                    "time": f"{elapsed:.3f}",
                }
            )
            logger.debug(
                "(%.3f) %s; args=%s",
                elapsed,
                sql,
                None,
                extra={"duration": elapsed, "sql": sql},
            )
216
217
218def split_tzname_delta(tzname: str) -> tuple[str, str | None, str | None]:
219    """
220    Split a time zone name into a 3-tuple of (name, sign, offset).
221    """
222    for sign in ["+", "-"]:
223        if sign in tzname:
224            name, offset = tzname.rsplit(sign, 1)
225            if offset and parse_time(offset):
226                return name, sign, offset
227    return tzname, None, None
228
229
230###############################################
231# Converters from Python to database (string) #
232###############################################
233
234
def split_identifier(identifier: str) -> tuple[str, str]:
    """
    Split an SQL identifier into a ``(namespace, name)`` pair.

    The identifier (a table, column, or sequence name) may be prefixed by a
    namespace joined with '"."'. Only an identifier containing exactly one
    '"."' separator is split. Otherwise the namespace is "" and the whole
    identifier is the name. Surrounding double quotes are stripped from both
    parts.
    """
    parts = identifier.split('"."')
    if len(parts) == 2:
        namespace, name = parts
    else:
        namespace, name = "", identifier
    return namespace.strip('"'), name.strip('"')
247
248
def truncate_name(identifier: str, length: int | None = None, hash_len: int = 4) -> str:
    """
    Shorten an SQL identifier to a repeatable mangled version with the given
    length.

    Only the name portion (as returned by ``split_identifier``) is measured
    and shortened. If there is a namespace, e.g. USERNAME"."TABLE, it is kept
    and re-joined with '"."'. The identifier is returned unchanged when
    ``length`` is None or the name already fits. Otherwise the name is cut to
    ``length - hash_len`` characters and a ``hash_len``-character digest of
    the full name is appended.

    NOTE(review): assumes ``hash_len <= length``. A larger ``hash_len`` makes
    the slice bound negative and the result exceeds ``length``.
    """
    namespace, name = split_identifier(identifier)
    if length is not None and len(name) > length:
        prefix = f'{namespace}"."' if namespace else ""
        head = name[: length - hash_len]
        return prefix + head + names_digest(name, length=hash_len)
    return identifier
268
269
def names_digest(*args: str, length: int) -> str:
    """
    Return the first ``length`` hex characters of an MD5 digest of ``args``.

    Each argument is encoded with ``str.encode()`` (UTF-8) and hashed back to
    back with no separator, so ("ab", "c") and ("a", "bc") produce the same
    digest. The output carries ``4 * length`` bits. It is meant to shorten
    identifying names repeatably, not for security (hence
    ``usedforsecurity=False``).
    """
    payload = b"".join(arg.encode() for arg in args)
    return md5(payload, usedforsecurity=False).hexdigest()[:length]
279
280
def strip_quotes(table_name: str) -> str:
    """
    Remove one pair of enclosing double quotes from a table name, if present.

    The unquoted form is safe to embed in index names, sequence names, etc.
    Names that do not both start and end with '"' are returned unchanged.
    """
    if not (table_name.startswith('"') and table_name.endswith('"')):
        return table_name
    return table_name[1:-1]