Plain is headed towards 1.0! Subscribe for development updates →

  1from __future__ import annotations
  2
  3import datetime
  4import decimal
  5import functools
  6import logging
  7import time
  8from collections.abc import Generator, Iterator, Mapping, Sequence
  9from contextlib import contextmanager
 10from hashlib import md5
 11from types import TracebackType
 12from typing import TYPE_CHECKING, Any, Protocol, Self
 13
 14from plain.models.db import NotSupportedError
 15from plain.models.otel import db_span
 16from plain.utils.dateparse import parse_time
 17
 18if TYPE_CHECKING:
 19    from plain.models.backends.base.base import BaseDatabaseWrapper
 20
 21logger = logging.getLogger("plain.models.backends")
 22
 23
 24class DBAPICursor(Protocol):
 25    """Protocol for DB-API 2.0 (PEP 249) cursor objects."""
 26
 27    @property
 28    def description(self) -> Sequence[Any] | None:
 29        """Column descriptions from the last query."""
 30        ...
 31
 32    @property
 33    def rowcount(self) -> int:
 34        """Number of rows affected by the last query."""
 35        ...
 36
 37    @property
 38    def lastrowid(self) -> int:
 39        """ID of the last inserted row (if applicable)."""
 40        ...
 41
 42    def close(self) -> None:
 43        """Close the cursor."""
 44        ...
 45
 46    def callproc(self, procname: str, *args: Any, **kwargs: Any) -> Any:
 47        """Call a stored database procedure (optional in DB-API 2.0)."""
 48        ...
 49
 50    def execute(
 51        self, sql: str, params: Sequence[Any] | Mapping[str, Any] | None = None
 52    ) -> Any:
 53        """Execute a database operation."""
 54        ...
 55
 56    def executemany(self, sql: str, params: Sequence[Sequence[Any]]) -> Any:
 57        """Execute a database operation multiple times."""
 58        ...
 59
 60    def fetchone(self) -> tuple[Any, ...] | None:
 61        """Fetch the next row of a query result set."""
 62        ...
 63
 64    def fetchmany(self, size: int = 0) -> list[tuple[Any, ...]]:
 65        """Fetch the next set of rows of a query result set."""
 66        ...
 67
 68    def fetchall(self) -> list[tuple[Any, ...]]:
 69        """Fetch all remaining rows of a query result set."""
 70        ...
 71
 72    def __iter__(self) -> Iterator[tuple[Any, ...]]:
 73        """Iterate over rows in the result set."""
 74        ...
 75
 76
 77class CursorWrapper:
 78    def __init__(self, cursor: DBAPICursor, db: BaseDatabaseWrapper) -> None:
 79        self.cursor = cursor
 80        self.db = db
 81
 82    WRAP_ERROR_ATTRS = frozenset(["fetchone", "fetchmany", "fetchall", "nextset"])
 83
 84    def __getattr__(self, attr: str) -> Any:
 85        cursor_attr = getattr(self.cursor, attr)
 86        if attr in CursorWrapper.WRAP_ERROR_ATTRS:
 87            return self.db.wrap_database_errors(cursor_attr)
 88        else:
 89            return cursor_attr
 90
 91    def __iter__(self) -> Iterator[tuple[Any, ...]]:
 92        with self.db.wrap_database_errors:
 93            yield from self.cursor
 94
 95    def __enter__(self) -> Self:
 96        return self
 97
 98    def __exit__(
 99        self,
100        type: type[BaseException] | None,
101        value: BaseException | None,
102        traceback: TracebackType | None,
103    ) -> None:
104        # Close instead of passing through to avoid backend-specific behavior
105        # (#17671). Catch errors liberally because errors in cleanup code
106        # aren't useful.
107        try:
108            self.close()
109        except self.db.Database.Error:
110            pass
111
112    # The following methods cannot be implemented in __getattr__, because the
113    # code must run when the method is invoked, not just when it is accessed.
114
115    def callproc(
116        self,
117        procname: str,
118        params: Sequence[Any] | None = None,
119        kparams: Mapping[str, Any] | None = None,
120    ) -> Any:
121        # Keyword parameters for callproc aren't supported in PEP 249, but the
122        # database driver may support them (e.g. cx_Oracle).
123        if kparams is not None and not self.db.features.supports_callproc_kwargs:
124            raise NotSupportedError(
125                "Keyword parameters for callproc are not supported on this "
126                "database backend."
127            )
128        self.db.validate_no_broken_transaction()
129        with self.db.wrap_database_errors:
130            if params is None and kparams is None:
131                return self.cursor.callproc(procname)
132            elif kparams is None:
133                return self.cursor.callproc(procname, params)
134            else:
135                params = params or ()
136                return self.cursor.callproc(procname, params, kparams)
137
138    def execute(
139        self, sql: str, params: Sequence[Any] | Mapping[str, Any] | None = None
140    ) -> Self:
141        return self._execute_with_wrappers(
142            sql, params, many=False, executor=self._execute
143        )
144
145    def executemany(self, sql: str, param_list: Sequence[Sequence[Any]]) -> Self:
146        return self._execute_with_wrappers(
147            sql, param_list, many=True, executor=self._executemany
148        )
149
150    def _execute_with_wrappers(
151        self, sql: str, params: Any, many: bool, executor: Any
152    ) -> Self:
153        context: dict[str, Any] = {"connection": self.db, "cursor": self}
154        for wrapper in reversed(self.db.execute_wrappers):
155            executor = functools.partial(wrapper, executor)
156        executor(sql, params, many, context)
157        return self
158
159    def _execute(self, sql: str, params: Any, *ignored_wrapper_args: Any) -> None:
160        # Wrap in an OpenTelemetry span with standard attributes.
161        with db_span(self.db, sql, params=params):
162            self.db.validate_no_broken_transaction()
163            with self.db.wrap_database_errors:
164                if params is None:
165                    self.cursor.execute(sql)
166                else:
167                    self.cursor.execute(sql, params)
168
169    def _executemany(
170        self, sql: str, param_list: Any, *ignored_wrapper_args: Any
171    ) -> None:
172        with db_span(self.db, sql, many=True, params=param_list):
173            self.db.validate_no_broken_transaction()
174            with self.db.wrap_database_errors:
175                self.cursor.executemany(sql, param_list)
176
177
class CursorDebugWrapper(CursorWrapper):
    """``CursorWrapper`` that also records every executed statement.

    Each ``execute``/``executemany`` call is timed; the SQL and elapsed
    seconds are appended to ``db.queries_log`` and logged at DEBUG level,
    whether or not the statement raised.
    """

    # XXX callproc isn't instrumented at this time.

    def execute(
        self, sql: str, params: Sequence[Any] | Mapping[str, Any] | None = None
    ) -> Self:
        # Record what the backend reports via ``ops.last_executed_query``
        # rather than the raw ``sql`` argument.
        with self.debug_sql(sql, params, use_last_executed_query=True):
            super().execute(sql, params)
        return self

    def executemany(self, sql: str, param_list: Sequence[Sequence[Any]]) -> Self:
        with self.debug_sql(sql, param_list, many=True):
            super().executemany(sql, param_list)
        return self

    @contextmanager
    def debug_sql(
        self,
        sql: str | None = None,
        params: Any = None,
        use_last_executed_query: bool = False,
        many: bool = False,
    ) -> Generator[None, None, None]:
        """Time the enclosed block and record ``sql`` in the query log.

        Args:
            sql: statement text to record.
            params: parameters sent with the statement (the parameter list
                when ``many`` is true).
            use_last_executed_query: record
                ``db.ops.last_executed_query(cursor, sql, params)`` instead
                of ``sql``.
            many: prefix the recorded SQL with the number of parameter sets.
        """
        t0 = time.monotonic()
        try:
            yield
        finally:
            elapsed = time.monotonic() - t0
            if use_last_executed_query:
                sql = self.db.ops.last_executed_query(self.cursor, sql, params)  # type: ignore[arg-type]
            if many:
                try:
                    times = len(params)
                except TypeError:
                    # params could be an iterator.
                    times = "?"
                logged_sql = f"{times} times: {sql}"
            else:
                logged_sql = sql
            self.db.queries_log.append(
                {"sql": logged_sql, "time": f"{elapsed:.3f}"}
            )
            logger.debug(
                "(%.3f) %s; args=%s",
                elapsed,
                sql,
                params,
                extra={"duration": elapsed, "sql": sql, "params": params},
            )
231
232
@contextmanager
def debug_transaction(
    connection: BaseDatabaseWrapper, sql: str
) -> Generator[None, None, None]:
    """Time the enclosed block and, if enabled, log it under the label ``sql``.

    Nothing is recorded unless ``connection.queries_logged`` is true. In
    that case, the entry is appended to ``connection.queries_log`` and
    logged at DEBUG level, even when the enclosed block raised.
    """
    began = time.monotonic()
    try:
        yield
    finally:
        if connection.queries_logged:
            elapsed = time.monotonic() - began
            connection.queries_log.append(
                {"sql": f"{sql}", "time": f"{elapsed:.3f}"}
            )
            logger.debug(
                "(%.3f) %s; args=%s",
                elapsed,
                sql,
                None,
                extra={"duration": elapsed, "sql": sql},
            )
260
261
def split_tzname_delta(tzname: str) -> tuple[str, str | None, str | None]:
    """
    Break a time zone name into ``(name, sign, offset)``.

    "+" is tried before "-", splitting at the last occurrence. The split is
    accepted only when the text after the sign is non-empty and
    ``parse_time`` returns a truthy result for it. Otherwise the name is
    returned unchanged with ``None`` for sign and offset.
    """
    for sign in ("+", "-"):
        if sign not in tzname:
            continue
        name, _, offset = tzname.rpartition(sign)
        if offset and parse_time(offset):
            return name, sign, offset
    return tzname, None, None
272
273
274###############################################
275# Converters from database (string) to Python #
276###############################################
277
278
def typecast_date(s: str | None) -> datetime.date | None:
    """Convert a ``"YYYY-MM-DD"`` database string into a ``datetime.date``.

    An empty string or NULL (``None``) yields ``None``.
    """
    if not s:
        return None
    return datetime.date(*[int(part) for part in s.split("-")])
283
284
def typecast_time(
    s: str | None,
) -> datetime.time | None:
    """Convert an ``"HH:MM:SS[.fraction]"`` string into a naive ``datetime.time``.

    An empty string or NULL (``None``) yields ``None``. The fractional part
    is right-padded with zeros / cut to six digits and read as
    microseconds. No tzinfo is attached.
    """
    if not s:
        return None
    hour, minutes, seconds = s.split(":")
    fraction = "0"
    if "." in seconds:
        seconds, fraction = seconds.split(".")
    return datetime.time(
        int(hour), int(minutes), int(seconds), int(fraction.ljust(6, "0")[:6])
    )
298
299
def typecast_timestamp(
    s: str | None,
) -> datetime.date | datetime.datetime | None:
    """Convert a database timestamp string into a naive ``datetime``.

    Accepts forms such as ``"2005-07-29 15:48:00.590358-05"`` or
    ``"2005-07-29 09:56:00-05"``. A trailing ``-``/``+`` offset is discarded
    (no tzinfo is stored). A value without a space is treated as a bare
    date and handed to ``typecast_date``. Empty/``None`` input yields
    ``None``.
    """
    if not s:
        return None
    if " " not in s:
        return typecast_date(s)
    date_text, time_text = s.split()
    # Discard the time zone offset, if any.
    if "-" in time_text:
        time_text = time_text.split("-", 1)[0]
    elif "+" in time_text:
        time_text = time_text.split("+", 1)[0]
    ymd = date_text.split("-")
    hms = time_text.split(":")
    whole_seconds = hms[2]
    fraction = "0"
    if "." in whole_seconds:
        whole_seconds, fraction = whole_seconds.split(".")
    return datetime.datetime(
        int(ymd[0]),
        int(ymd[1]),
        int(ymd[2]),
        int(hms[0]),
        int(hms[1]),
        int(whole_seconds),
        int(fraction.ljust(6, "0")[:6]),
    )
331
332
333###############################################
334# Converters from Python to database (string) #
335###############################################
336
337
def split_identifier(identifier: str) -> tuple[str, str]:
    """
    Return ``(namespace, name)`` for a possibly namespaced SQL identifier.

    The identifier (a table, column, or sequence name) may be written as
    ``"NAMESPACE"."NAME"``. Unless it splits on ``"."`` into exactly two
    parts, the namespace is empty. Surrounding double quotes are stripped
    from both parts.
    """
    parts = identifier.split('"."')
    if len(parts) == 2:
        namespace, name = parts
    else:
        namespace, name = "", identifier
    return namespace.strip('"'), name.strip('"')
350
351
def truncate_name(identifier: str, length: int | None = None, hash_len: int = 4) -> str:
    """
    Shorten an SQL identifier to ``length`` characters in a repeatable way.

    Only the name part is considered and shortened; a namespace prefix such
    as ``USERNAME"."TABLE`` is kept. A name longer than ``length`` is cut to
    ``length - hash_len`` characters, and a ``hash_len``-character digest of
    the full name is appended. Identifiers that already fit, or calls with
    ``length=None``, are returned unchanged.
    """
    namespace, name = split_identifier(identifier)
    if length is not None and len(name) > length:
        prefix = f'{namespace}"."' if namespace else ""
        suffix = names_digest(name, length=hash_len)
        return f"{prefix}{name[: length - hash_len]}{suffix}"
    return identifier
371
372
def names_digest(*args: str, length: int) -> str:
    """
    Return a short, stable hex digest of ``args`` for building identifiers.

    The UTF-8 encodings of the arguments are fed, in order, into a single
    MD5 hash, and the first ``length`` hex characters are returned. Each
    hex character carries 4 bits, so ``length=8`` gives a 32-bit digest.
    Not intended for security purposes.
    """
    hasher = md5(usedforsecurity=False)
    for chunk in args:
        hasher.update(chunk.encode())
    return hasher.hexdigest()[:length]
382
383
def format_number(
    value: decimal.Decimal | None, max_digits: int | None, decimal_places: int | None
) -> str | None:
    """
    Render ``value`` as a plain (non-exponent) decimal string.

    Args:
        value: number to format; ``None`` is passed through as ``None``.
        max_digits: precision of the working context, or ``None`` to keep
            the current context's precision.
        decimal_places: digits after the point to quantize to, or ``None``
            to keep the value as-is but forbid any rounding.

    Raises:
        decimal.Rounded: ``decimal_places`` is ``None`` and ``value`` has
            more significant digits than the working precision.
    """
    if value is None:
        return None
    ctx = decimal.getcontext().copy()
    if max_digits is not None:
        ctx.prec = max_digits
    if decimal_places is None:
        # No scale requested: trap rounding so lost digits raise instead of
        # being silently dropped.
        ctx.traps[decimal.Rounded] = True
        value = ctx.create_decimal(value)
    else:
        exponent = decimal.Decimal(1).scaleb(-decimal_places)
        value = value.quantize(exponent, context=ctx)
    return f"{value:f}"
404
405
def strip_quotes(table_name: str) -> str:
    """
    Remove one pair of enclosing double quotes from ``table_name``, if present.

    This makes the name usable inside index names, sequence names, etc.
    For example, '"USER"."TABLE"' (an Oracle-style name) becomes
    'USER"."TABLE'. Names that are not wrapped in quotes are returned as-is.
    """
    if table_name.startswith('"') and table_name.endswith('"'):
        return table_name[1:-1]
    return table_name
412    has_quotes = table_name.startswith('"') and table_name.endswith('"')
413    return table_name[1:-1] if has_quotes else table_name