Plain is headed towards 1.0! Subscribe for development updates →

  1from __future__ import annotations
  2
  3import datetime
  4import decimal
  5import functools
  6import logging
  7import time
  8from collections.abc import Generator, Iterator
  9from contextlib import contextmanager
 10from hashlib import md5
 11from typing import TYPE_CHECKING, Any
 12
 13from plain.models.db import NotSupportedError
 14from plain.models.otel import db_span
 15from plain.utils.dateparse import parse_time
 16
 17if TYPE_CHECKING:
 18    from plain.models.backends.base.base import BaseDatabaseWrapper
 19
 20logger = logging.getLogger("plain.models.backends")
 21
 22
 23class CursorWrapper:
 24    def __init__(self, cursor: Any, db: Any) -> None:
 25        self.cursor = cursor
 26        self.db = db
 27
 28    WRAP_ERROR_ATTRS = frozenset(["fetchone", "fetchmany", "fetchall", "nextset"])
 29
 30    def __getattr__(self, attr: str) -> Any:
 31        cursor_attr = getattr(self.cursor, attr)
 32        if attr in CursorWrapper.WRAP_ERROR_ATTRS:
 33            return self.db.wrap_database_errors(cursor_attr)
 34        else:
 35            return cursor_attr
 36
 37    def __iter__(self) -> Iterator[Any]:
 38        with self.db.wrap_database_errors:
 39            yield from self.cursor
 40
 41    def __enter__(self) -> CursorWrapper:
 42        return self
 43
 44    def __exit__(self, type: Any, value: Any, traceback: Any) -> None:
 45        # Close instead of passing through to avoid backend-specific behavior
 46        # (#17671). Catch errors liberally because errors in cleanup code
 47        # aren't useful.
 48        try:
 49            self.close()  # type: ignore[attr-defined]
 50        except self.db.Database.Error:
 51            pass
 52
 53    # The following methods cannot be implemented in __getattr__, because the
 54    # code must run when the method is invoked, not just when it is accessed.
 55
 56    def callproc(self, procname: str, params: Any = None, kparams: Any = None) -> Any:
 57        # Keyword parameters for callproc aren't supported in PEP 249, but the
 58        # database driver may support them (e.g. cx_Oracle).
 59        if kparams is not None and not self.db.features.supports_callproc_kwargs:
 60            raise NotSupportedError(
 61                "Keyword parameters for callproc are not supported on this "
 62                "database backend."
 63            )
 64        self.db.validate_no_broken_transaction()
 65        with self.db.wrap_database_errors:
 66            if params is None and kparams is None:
 67                return self.cursor.callproc(procname)
 68            elif kparams is None:
 69                return self.cursor.callproc(procname, params)
 70            else:
 71                params = params or ()
 72                return self.cursor.callproc(procname, params, kparams)
 73
 74    def execute(self, sql: str, params: Any = None) -> Any:
 75        return self._execute_with_wrappers(
 76            sql, params, many=False, executor=self._execute
 77        )
 78
 79    def executemany(self, sql: str, param_list: Any) -> Any:
 80        return self._execute_with_wrappers(
 81            sql, param_list, many=True, executor=self._executemany
 82        )
 83
 84    def _execute_with_wrappers(
 85        self, sql: str, params: Any, many: bool, executor: Any
 86    ) -> Any:
 87        context: dict[str, Any] = {"connection": self.db, "cursor": self}
 88        for wrapper in reversed(self.db.execute_wrappers):
 89            executor = functools.partial(wrapper, executor)
 90        return executor(sql, params, many, context)
 91
 92    def _execute(self, sql: str, params: Any, *ignored_wrapper_args: Any) -> Any:
 93        # Wrap in an OpenTelemetry span with standard attributes.
 94        with db_span(self.db, sql, params=params):
 95            self.db.validate_no_broken_transaction()
 96            with self.db.wrap_database_errors:
 97                if params is None:
 98                    return self.cursor.execute(sql)
 99                else:
100                    return self.cursor.execute(sql, params)
101
102    def _executemany(
103        self, sql: str, param_list: Any, *ignored_wrapper_args: Any
104    ) -> Any:
105        with db_span(self.db, sql, many=True, params=param_list):
106            self.db.validate_no_broken_transaction()
107            with self.db.wrap_database_errors:
108                return self.cursor.executemany(sql, param_list)
109
110
class CursorDebugWrapper(CursorWrapper):
    """
    CursorWrapper that records each execute()/executemany() call.

    Every statement is appended to ``db.queries_log`` as a dict with the SQL
    text and the elapsed seconds (formatted to three decimals), and is also
    emitted as a DEBUG record on the module logger.
    """

    # XXX callproc isn't instrumented at this time.

    def execute(self, sql: str, params: Any = None) -> Any:
        # Record the SQL reported by db.ops.last_executed_query() rather than
        # the raw template.
        with self.debug_sql(sql, params, use_last_executed_query=True):
            return super().execute(sql, params)

    def executemany(self, sql: str, param_list: Any) -> Any:
        with self.debug_sql(sql, param_list, many=True):
            return super().executemany(sql, param_list)

    @contextmanager
    def debug_sql(
        self,
        sql: str | None = None,
        params: Any = None,
        use_last_executed_query: bool = False,
        many: bool = False,
    ) -> Generator[None, None, None]:
        """
        Time the enclosed block and record it, even if the block raises.

        With ``many=True`` the logged SQL is prefixed with the number of
        parameter sets (``"?"`` when ``params`` has no ``len()``, e.g. an
        iterator). Exceptions from the block are not suppressed.
        """
        began = time.monotonic()
        try:
            yield
        finally:
            elapsed = time.monotonic() - began
            if use_last_executed_query:
                sql = self.db.ops.last_executed_query(self.cursor, sql, params)
            if many:
                repeat: int | str
                try:
                    repeat = len(params)  # type: ignore[arg-type]
                except TypeError:
                    # params could be an iterator.
                    repeat = "?"
                entry_sql = f"{repeat} times: {sql}"
            else:
                entry_sql = sql
            self.db.queries_log.append({"sql": entry_sql, "time": f"{elapsed:.3f}"})
            logger.debug(
                "(%.3f) %s; args=%s",
                elapsed,
                sql,
                params,
                extra={"duration": elapsed, "sql": sql, "params": params},
            )
160
161
@contextmanager
def debug_transaction(
    connection: BaseDatabaseWrapper, sql: str
) -> Generator[None, None, None]:
    """
    Time the enclosed block as the execution of ``sql`` on ``connection``.

    Only when ``connection.queries_logged`` is true, the statement and its
    elapsed seconds (three decimals) are appended to
    ``connection.queries_log`` and logged at DEBUG level. Recording happens
    even if the block raises; the exception is not suppressed.
    """
    began = time.monotonic()
    try:
        yield
    finally:
        if connection.queries_logged:
            elapsed = time.monotonic() - began
            connection.queries_log.append({"sql": f"{sql}", "time": f"{elapsed:.3f}"})
            # No bound parameters exist here, so args is always logged as None.
            logger.debug(
                "(%.3f) %s; args=%s",
                elapsed,
                sql,
                None,
                extra={"duration": elapsed, "sql": sql},
            )
189
190
def split_tzname_delta(tzname: str) -> tuple[str, str | None, str | None]:
    """
    Split a time zone name into a 3-tuple of (name, sign, offset).

    "+" is tried before "-". The text after the last occurrence of the sign
    is taken as the offset only if it is non-empty and ``parse_time`` returns
    a truthy value for it; otherwise ``(tzname, None, None)`` is returned.
    """
    for sign in "+-":
        if sign not in tzname:
            continue
        name, _, offset = tzname.rpartition(sign)
        if offset and parse_time(offset):
            return name, sign, offset
    return tzname, None, None
201
202
203###############################################
204# Converters from database (string) to Python #
205###############################################
206
207
def typecast_date(s: str | None) -> datetime.date | None:
    """Parse a ``YYYY-MM-DD`` string into a date; falsy input (None, "") yields None."""
    if not s:
        return None
    parts = [int(piece) for piece in s.split("-")]
    return datetime.date(*parts)
212
213
def typecast_time(
    s: str | None,
) -> datetime.time | None:  # does NOT store time zone information
    """
    Parse ``HH:MM:SS[.fraction]`` into a naive time; falsy input yields None.

    The fractional part is right-padded with zeros and cut to six digits, so
    ``".5"`` means 500000 microseconds and extra precision is truncated.
    """
    if not s:
        return None
    hour, minutes, seconds = s.split(":")
    seconds, fraction = seconds.split(".") if "." in seconds else (seconds, "0")
    micro = int(fraction.ljust(6, "0")[:6])
    return datetime.time(int(hour), int(minutes), int(seconds), micro)
227
228
def typecast_timestamp(
    s: str | None,
) -> datetime.date | datetime.datetime | None:  # does NOT store time zone information
    """
    Parse ``YYYY-MM-DD HH:MM:SS[.fraction][+-offset]`` into a naive datetime.

    Examples: ``"2005-07-29 15:48:00.590358-05"``, ``"2005-07-29 09:56:00-05"``.
    Any UTC offset after the time is discarded. A string without a space is
    handed to ``typecast_date``; falsy input yields None. The fractional part
    is zero-padded/truncated to six digits of microseconds.
    """
    if not s:
        return None
    if " " not in s:
        return typecast_date(s)
    date_part, time_part = s.split()
    # Drop the time zone offset; "-" is checked before "+".
    for marker in ("-", "+"):
        if marker in time_part:
            time_part = time_part.split(marker, 1)[0]
            break
    ymd = date_part.split("-")
    clock = time_part.split(":")
    seconds, fraction = (
        clock[2].split(".") if "." in clock[2] else (clock[2], "0")
    )
    return datetime.datetime(
        int(ymd[0]),
        int(ymd[1]),
        int(ymd[2]),
        int(clock[0]),
        int(clock[1]),
        int(seconds),
        int(fraction.ljust(6, "0")[:6]),
    )
260
261
262###############################################
263# Converters from Python to database (string) #
264###############################################
265
266
def split_identifier(identifier: str) -> tuple[str, str]:
    """
    Split an SQL identifier into a two element tuple of (namespace, name).

    The identifier (a table, column, or sequence name) may be prefixed by a
    namespace separated with ``"."``. If there is not exactly one such
    separator, the namespace is empty and the whole identifier is the name.
    Surrounding double quotes are stripped from both parts.
    """
    pieces = identifier.split('"."')
    if len(pieces) == 2:
        namespace, name = pieces
    else:
        namespace, name = "", identifier
    return namespace.strip('"'), name.strip('"')
279
280
def truncate_name(identifier: str, length: int | None = None, hash_len: int = 4) -> str:
    """
    Shorten an SQL identifier to a repeatable mangled version with the given
    length.

    The identifier is returned unchanged when ``length`` is None or the name
    part already fits. Otherwise the name is cut to ``length - hash_len``
    characters and suffixed with a ``hash_len``-character digest of the full
    name. If a quote-stripped name contains a namespace, e.g.
    USERNAME"."TABLE, only the table portion is truncated, and the result is
    returned without outer quotes.
    """
    namespace, name = split_identifier(identifier)
    if length is not None and len(name) > length:
        prefix = f'{namespace}"."' if namespace else ""
        suffix = names_digest(name, length=hash_len)
        return prefix + name[: length - hash_len] + suffix
    return identifier
300
301
def names_digest(*args: str, length: int) -> str:
    """
    Return the first ``length`` hex characters of the MD5 digest of ``args``.

    Used to build short, repeatable suffixes when shortening identifier names.

    The arguments are UTF-8 encoded and fed to the hash in order with no
    separator, so ``("ab", "c")`` and ``("a", "bc")`` produce the same digest.
    Each hex character carries 4 bits, so the result holds ``4 * length``
    bits of the digest (32 bits only when ``length=8``). The MD5 hexdigest
    has 32 characters, so larger ``length`` values return all 32.

    MD5 is created with ``usedforsecurity=False``: this is a naming helper,
    not a security primitive.
    """
    hasher = md5(usedforsecurity=False)
    for arg in args:
        hasher.update(arg.encode())
    return hasher.hexdigest()[:length]
311
312
def format_number(
    value: decimal.Decimal | None, max_digits: int | None, decimal_places: int | None
) -> str | None:
    """
    Format a number into a string with the requisite number of digits and
    decimal places.

    None passes through unchanged. The arithmetic uses a copy of the current
    decimal context with its precision set to ``max_digits`` when given. With
    ``decimal_places`` the value is quantized to that many places. Without
    it, the ``Rounded`` signal is trapped, so a value needing more than
    ``max_digits`` digits raises ``decimal.Rounded`` instead of losing digits.
    The result is always in fixed-point (non-scientific) notation.
    """
    if value is None:
        return None
    ctx = decimal.getcontext().copy()
    if max_digits is not None:
        ctx.prec = max_digits
    if decimal_places is None:
        ctx.traps[decimal.Rounded] = 1  # type: ignore[assignment]
        value = ctx.create_decimal(value)
    else:
        quantum = decimal.Decimal(1).scaleb(-decimal_places)
        value = value.quantize(quantum, context=ctx)
    return format(value, "f")
333
334
def strip_quotes(table_name: str) -> str:
    """
    Strip quotes off of quoted table names to make them safe for use in index
    names, sequence names, etc. For example '"USER"."TABLE"' (an Oracle naming
    scheme) becomes 'USER"."TABLE'.

    Only one leading and one trailing character are removed, and only when
    the name both starts and ends with a double quote.
    """
    if table_name.startswith('"') and table_name.endswith('"'):
        return table_name[1:-1]
    return table_name
341    has_quotes = table_name.startswith('"') and table_name.endswith('"')
342    return table_name[1:-1] if has_quotes else table_name