1from __future__ import annotations
  2
  3import datetime
  4import decimal
  5import functools
  6import logging
  7import time
  8from collections.abc import Generator, Iterator, Mapping, Sequence
  9from contextlib import contextmanager
 10from hashlib import md5
 11from types import TracebackType
 12from typing import TYPE_CHECKING, Any, Protocol, Self
 13
 14from plain.models.db import NotSupportedError
 15from plain.models.otel import db_span
 16from plain.utils.dateparse import parse_time
 17
 18if TYPE_CHECKING:
 19    from plain.models.backends.base.base import BaseDatabaseWrapper
 20
# Module logger used by CursorDebugWrapper.debug_sql() and debug_transaction()
# below to emit executed SQL and timings at DEBUG level.
logger = logging.getLogger("plain.models.backends")
 22
 23
 24class DBAPICursor(Protocol):
 25    """Protocol for DB-API 2.0 (PEP 249) cursor objects."""
 26
 27    @property
 28    def description(self) -> Sequence[Any] | None:
 29        """Column descriptions from the last query."""
 30        ...
 31
 32    @property
 33    def rowcount(self) -> int:
 34        """Number of rows affected by the last query."""
 35        ...
 36
 37    @property
 38    def lastrowid(self) -> int:
 39        """ID of the last inserted row (if applicable)."""
 40        ...
 41
 42    def close(self) -> None:
 43        """Close the cursor."""
 44        ...
 45
 46    def callproc(self, procname: str, *args: Any, **kwargs: Any) -> Any:
 47        """Call a stored database procedure (optional in DB-API 2.0)."""
 48        ...
 49
 50    def execute(
 51        self, sql: str, params: Sequence[Any] | Mapping[str, Any] | None = None
 52    ) -> Any:
 53        """Execute a database operation."""
 54        ...
 55
 56    def executemany(self, sql: str, params: Sequence[Sequence[Any]]) -> Any:
 57        """Execute a database operation multiple times."""
 58        ...
 59
 60    def fetchone(self) -> tuple[Any, ...] | None:
 61        """Fetch the next row of a query result set."""
 62        ...
 63
 64    def fetchmany(self, size: int | None = None) -> list[tuple[Any, ...]]:
 65        """Fetch the next set of rows of a query result set."""
 66        ...
 67
 68    def fetchall(self) -> list[tuple[Any, ...]]:
 69        """Fetch all remaining rows of a query result set."""
 70        ...
 71
 72    def __iter__(self) -> Iterator[tuple[Any, ...]]:
 73        """Iterate over rows in the result set."""
 74        ...
 75
 76
 77class CursorWrapper:
 78    def __init__(self, cursor: DBAPICursor, db: BaseDatabaseWrapper) -> None:
 79        self.cursor = cursor
 80        self.db = db
 81
 82    WRAP_ERROR_ATTRS = frozenset(["nextset"])
 83
 84    def __getattr__(self, attr: str) -> Any:
 85        cursor_attr = getattr(self.cursor, attr)
 86        if attr in CursorWrapper.WRAP_ERROR_ATTRS:
 87            return self.db.wrap_database_errors(cursor_attr)
 88        else:
 89            return cursor_attr
 90
 91    def __iter__(self) -> Iterator[tuple[Any, ...]]:
 92        with self.db.wrap_database_errors:
 93            yield from self.cursor
 94
 95    def fetchone(self) -> tuple[Any, ...] | None:
 96        with self.db.wrap_database_errors:
 97            return self.cursor.fetchone()
 98
 99    def fetchmany(self, size: int | None = None) -> list[tuple[Any, ...]]:
100        with self.db.wrap_database_errors:
101            if size is None:
102                return self.cursor.fetchmany()
103            return self.cursor.fetchmany(size)
104
105    def fetchall(self) -> list[tuple[Any, ...]]:
106        with self.db.wrap_database_errors:
107            return self.cursor.fetchall()
108
109    def __enter__(self) -> Self:
110        return self
111
112    def __exit__(
113        self,
114        type: type[BaseException] | None,
115        value: BaseException | None,
116        traceback: TracebackType | None,
117    ) -> None:
118        # Close instead of passing through to avoid backend-specific behavior
119        # (#17671). Catch errors liberally because errors in cleanup code
120        # aren't useful.
121        try:
122            self.close()
123        except self.db.Database.Error:
124            pass
125
126    # The following methods cannot be implemented in __getattr__, because the
127    # code must run when the method is invoked, not just when it is accessed.
128
129    def callproc(
130        self,
131        procname: str,
132        params: Sequence[Any] | None = None,
133        kparams: Mapping[str, Any] | None = None,
134    ) -> Any:
135        # Keyword parameters for callproc aren't supported in PEP 249, but the
136        # database driver may support them (e.g. cx_Oracle).
137        if kparams is not None and not self.db.features.supports_callproc_kwargs:
138            raise NotSupportedError(
139                "Keyword parameters for callproc are not supported on this "
140                "database backend."
141            )
142        self.db.validate_no_broken_transaction()
143        with self.db.wrap_database_errors:
144            if params is None and kparams is None:
145                return self.cursor.callproc(procname)
146            elif kparams is None:
147                return self.cursor.callproc(procname, params)
148            else:
149                params = params or ()
150                return self.cursor.callproc(procname, params, kparams)
151
152    def execute(
153        self, sql: str, params: Sequence[Any] | Mapping[str, Any] | None = None
154    ) -> Self:
155        return self._execute_with_wrappers(
156            sql, params, many=False, executor=self._execute
157        )
158
159    def executemany(self, sql: str, param_list: Sequence[Sequence[Any]]) -> Self:
160        return self._execute_with_wrappers(
161            sql, param_list, many=True, executor=self._executemany
162        )
163
164    def _execute_with_wrappers(
165        self, sql: str, params: Any, many: bool, executor: Any
166    ) -> Self:
167        context: dict[str, Any] = {"connection": self.db, "cursor": self}
168        for wrapper in reversed(self.db.execute_wrappers):
169            executor = functools.partial(wrapper, executor)
170        executor(sql, params, many, context)
171        return self
172
173    def _execute(self, sql: str, params: Any, *ignored_wrapper_args: Any) -> None:
174        # Wrap in an OpenTelemetry span with standard attributes.
175        with db_span(self.db, sql, params=params):
176            self.db.validate_no_broken_transaction()
177            with self.db.wrap_database_errors:
178                if params is None:
179                    self.cursor.execute(sql)
180                else:
181                    self.cursor.execute(sql, params)
182
183    def _executemany(
184        self, sql: str, param_list: Any, *ignored_wrapper_args: Any
185    ) -> None:
186        with db_span(self.db, sql, many=True, params=param_list):
187            self.db.validate_no_broken_transaction()
188            with self.db.wrap_database_errors:
189                self.cursor.executemany(sql, param_list)
190
191
class CursorDebugWrapper(CursorWrapper):
    """
    CursorWrapper that times every execute()/executemany() call, appends the
    statement and duration to ``db.queries_log``, and emits it through the
    module logger at DEBUG level.
    """

    # XXX callproc isn't instrumented at this time.

    def execute(
        self, sql: str, params: Sequence[Any] | Mapping[str, Any] | None = None
    ) -> Self:
        # The recorded SQL comes from db.ops.last_executed_query() afterwards.
        timer = self.debug_sql(sql, params, use_last_executed_query=True)
        with timer:
            super().execute(sql, params)
        return self

    def executemany(self, sql: str, param_list: Sequence[Sequence[Any]]) -> Self:
        timer = self.debug_sql(sql, param_list, many=True)
        with timer:
            super().executemany(sql, param_list)
        return self

    @contextmanager
    def debug_sql(
        self,
        sql: str | None = None,
        params: Any = None,
        use_last_executed_query: bool = False,
        many: bool = False,
    ) -> Generator[None, None, None]:
        """
        Time the body of the ``with`` block and record it, even if it raises.

        When ``use_last_executed_query`` is true, the SQL recorded is whatever
        ``db.ops.last_executed_query()`` reports for the raw cursor. When
        ``many`` is true, the queries_log entry is prefixed with the number of
        parameter sets ("?" if ``params`` has no length).
        """
        started = time.monotonic()
        try:
            yield
        finally:
            elapsed = time.monotonic() - started
            if use_last_executed_query:
                sql = self.db.ops.last_executed_query(self.cursor, sql, params)  # type: ignore[arg-type]
            if many:
                try:
                    count: Any = len(params)
                except TypeError:
                    # params could be an iterator.
                    count = "?"
                logged_sql = f"{count} times: {sql}"
            else:
                logged_sql = sql
            entry = {"sql": logged_sql, "time": f"{elapsed:.3f}"}
            self.db.queries_log.append(entry)
            details = {"duration": elapsed, "sql": sql, "params": params}
            logger.debug(
                "(%.3f) %s; args=%s", elapsed, sql, params, extra=details
            )
245
246
@contextmanager
def debug_transaction(
    connection: BaseDatabaseWrapper, sql: str
) -> Generator[None, None, None]:
    """
    Time the body of the ``with`` block for ``sql``.

    Afterwards (even if the block raised), if ``connection.queries_logged`` is
    true, append the SQL and elapsed seconds to ``connection.queries_log`` and
    emit a DEBUG record through the module logger. Nothing is recorded
    otherwise.
    """
    began = time.monotonic()
    try:
        yield
    finally:
        # No early return here: returning from ``finally`` would swallow
        # any exception raised inside the block.
        if connection.queries_logged:
            elapsed = time.monotonic() - began
            connection.queries_log.append(
                {"sql": f"{sql}", "time": f"{elapsed:.3f}"}
            )
            logger.debug(
                "(%.3f) %s; args=%s",
                elapsed,
                sql,
                None,
                extra={"duration": elapsed, "sql": sql},
            )
274
275
def split_tzname_delta(tzname: str) -> tuple[str, str | None, str | None]:
    """
    Split a time zone name into a 3-tuple of (name, sign, offset).

    "+" is tried before "-", splitting at the last occurrence of the sign.
    The split is accepted only when the text after the sign is non-empty and
    parse_time() returns a truthy value for it. Otherwise the result is
    (tzname, None, None).
    """
    for sign in ("+", "-"):
        if sign not in tzname:
            continue
        name, _, offset = tzname.rpartition(sign)
        if offset and parse_time(offset):
            return name, sign, offset
    return tzname, None, None
286
287
288###############################################
289# Converters from database (string) to Python #
290###############################################
291
292
def typecast_date(s: str | None) -> datetime.date | None:
    """Parse a "YYYY-MM-DD" string into a date; empty or None input gives None."""
    if not s:
        return None
    return datetime.date(*[int(piece) for piece in s.split("-")])
297
298
def typecast_time(
    s: str | None,
) -> datetime.time | None:  # does NOT store time zone information
    """
    Parse an "HH:MM:SS[.ffffff]" string into a naive time.

    The fractional part is right-padded with zeros and truncated to six
    digits (microseconds). Empty or None input yields None.
    """
    if not s:
        return None
    hour, minutes, seconds = s.split(":")
    fraction = "0"
    if "." in seconds:  # seconds carry a fractional part
        seconds, fraction = seconds.split(".")
    micro = int(fraction.ljust(6, "0")[:6])
    return datetime.time(int(hour), int(minutes), int(seconds), micro)
312
313
def typecast_timestamp(
    s: str | None,
) -> datetime.date | datetime.datetime | None:  # does NOT store time zone information
    """
    Parse a database timestamp string into a naive datetime.

    Example inputs: "2005-07-29 15:48:00.590358-05" and
    "2005-07-29 09:56:00-05". Any trailing UTC offset is discarded, not
    applied. A string containing no space is handed to typecast_date().
    Empty or None input yields None.
    """
    if not s:
        return None
    if " " not in s:
        return typecast_date(s)
    date_part, time_part = s.split()
    # Drop the time zone suffix; "-" takes precedence over "+".
    for marker in ("-", "+"):
        if marker in time_part:
            time_part = time_part.split(marker, 1)[0]
            break
    date_fields = date_part.split("-")
    time_fields = time_part.split(":")
    whole_seconds = time_fields[2]
    fraction = "0"
    if "." in whole_seconds:  # seconds carry a fractional part
        whole_seconds, fraction = whole_seconds.split(".")
    return datetime.datetime(
        int(date_fields[0]),
        int(date_fields[1]),
        int(date_fields[2]),
        int(time_fields[0]),
        int(time_fields[1]),
        int(whole_seconds),
        int(fraction.ljust(6, "0")[:6]),
    )
345
346
347###############################################
348# Converters from Python to database (string) #
349###############################################
350
351
def split_identifier(identifier: str) -> tuple[str, str]:
    """
    Split an SQL identifier into a two-element tuple of (namespace, name).

    The identifier may be a table, column, or sequence name, optionally
    prefixed by a namespace as in "namespace"."name". If it does not contain
    exactly one '"."' separator, the namespace is empty and the whole
    identifier is treated as the name. Leading and trailing double quotes are
    stripped from both parts.
    """
    pieces = identifier.split('"."')
    if len(pieces) == 2:
        namespace, name = pieces
    else:
        namespace, name = "", identifier
    return namespace.strip('"'), name.strip('"')
364
365
def truncate_name(identifier: str, length: int | None = None, hash_len: int = 4) -> str:
    """
    Shorten an SQL identifier to a repeatable mangled version with the given
    length.

    ``identifier`` is returned unchanged when ``length`` is None or the name
    already fits. Otherwise the name is cut to ``length - hash_len``
    characters and a ``hash_len``-character digest of the full name is
    appended. If the quote-stripped identifier contains a namespace, e.g.
    USERNAME"."TABLE, only the table portion is truncated.
    """
    namespace, name = split_identifier(identifier)
    if length is not None and len(name) > length:
        kept = name[: length - hash_len]
        suffix = names_digest(name, length=hash_len)
        prefix = f'{namespace}"."' if namespace else ""
        return prefix + kept + suffix
    return identifier
385
386
def names_digest(*args: str, length: int) -> str:
    """
    Return the first ``length`` hex characters (4 * ``length`` bits) of the
    MD5 digest of all ``args`` encoded and concatenated in order.

    Used to build short, stable suffixes when shortening identifying names.
    This is not a security use of MD5.
    """
    payload = b"".join(arg.encode() for arg in args)
    return md5(payload, usedforsecurity=False).hexdigest()[:length]
396
397
def format_number(
    value: decimal.Decimal | None, max_digits: int | None, decimal_places: int | None
) -> str | None:
    """
    Render ``value`` as a fixed-point string with at most ``max_digits``
    significant digits and ``decimal_places`` fractional digits.

    Return None when ``value`` is None. With ``decimal_places`` set, the
    value is quantized to that many places under a copy of the current
    context (precision limited to ``max_digits`` when given). Without it, the
    value is re-created under a context that traps ``decimal.Rounded``, so a
    value that would need rounding raises instead of silently losing digits.
    """
    if value is None:
        return None
    ctx = decimal.getcontext().copy()
    if max_digits is not None:
        ctx.prec = max_digits
    if decimal_places is None:
        ctx.traps[decimal.Rounded] = True
        result = ctx.create_decimal(value)
    else:
        quantum = decimal.Decimal(1).scaleb(-decimal_places)
        result = value.quantize(quantum, context=ctx)
    return format(result, "f")
418
419
def strip_quotes(table_name: str) -> str:
    """
    Remove one pair of enclosing double quotes from a quoted table name so it
    can be embedded in index names, sequence names, etc.

    For example '"USER"."TABLE"' (an Oracle naming scheme) becomes
    'USER"."TABLE'. Names not both starting and ending with '"' are returned
    unchanged.
    """
    if table_name.startswith('"') and table_name.endswith('"'):
        return table_name[1:-1]
    return table_name