Plain is headed towards 1.0! Subscribe for development updates →

  1import datetime
  2import decimal
  3import functools
  4import logging
  5import time
  6from contextlib import contextmanager
  7from hashlib import md5
  8
  9from plain.models.db import NotSupportedError
 10from plain.models.otel import db_span
 11from plain.utils.dateparse import parse_time
 12
# Module logger; CursorDebugWrapper.debug_sql and debug_transaction below emit
# their per-statement debug records through it.
logger = logging.getLogger("plain.models.backends")
 14
 15
 16class CursorWrapper:
 17    def __init__(self, cursor, db):
 18        self.cursor = cursor
 19        self.db = db
 20
 21    WRAP_ERROR_ATTRS = frozenset(["fetchone", "fetchmany", "fetchall", "nextset"])
 22
 23    def __getattr__(self, attr):
 24        cursor_attr = getattr(self.cursor, attr)
 25        if attr in CursorWrapper.WRAP_ERROR_ATTRS:
 26            return self.db.wrap_database_errors(cursor_attr)
 27        else:
 28            return cursor_attr
 29
 30    def __iter__(self):
 31        with self.db.wrap_database_errors:
 32            yield from self.cursor
 33
 34    def __enter__(self):
 35        return self
 36
 37    def __exit__(self, type, value, traceback):
 38        # Close instead of passing through to avoid backend-specific behavior
 39        # (#17671). Catch errors liberally because errors in cleanup code
 40        # aren't useful.
 41        try:
 42            self.close()
 43        except self.db.Database.Error:
 44            pass
 45
 46    # The following methods cannot be implemented in __getattr__, because the
 47    # code must run when the method is invoked, not just when it is accessed.
 48
 49    def callproc(self, procname, params=None, kparams=None):
 50        # Keyword parameters for callproc aren't supported in PEP 249, but the
 51        # database driver may support them (e.g. cx_Oracle).
 52        if kparams is not None and not self.db.features.supports_callproc_kwargs:
 53            raise NotSupportedError(
 54                "Keyword parameters for callproc are not supported on this "
 55                "database backend."
 56            )
 57        self.db.validate_no_broken_transaction()
 58        with self.db.wrap_database_errors:
 59            if params is None and kparams is None:
 60                return self.cursor.callproc(procname)
 61            elif kparams is None:
 62                return self.cursor.callproc(procname, params)
 63            else:
 64                params = params or ()
 65                return self.cursor.callproc(procname, params, kparams)
 66
 67    def execute(self, sql, params=None):
 68        return self._execute_with_wrappers(
 69            sql, params, many=False, executor=self._execute
 70        )
 71
 72    def executemany(self, sql, param_list):
 73        return self._execute_with_wrappers(
 74            sql, param_list, many=True, executor=self._executemany
 75        )
 76
 77    def _execute_with_wrappers(self, sql, params, many, executor):
 78        context = {"connection": self.db, "cursor": self}
 79        for wrapper in reversed(self.db.execute_wrappers):
 80            executor = functools.partial(wrapper, executor)
 81        return executor(sql, params, many, context)
 82
 83    def _execute(self, sql, params, *ignored_wrapper_args):
 84        # Wrap in an OpenTelemetry span with standard attributes.
 85        with db_span(self.db, sql, params=params):
 86            self.db.validate_no_broken_transaction()
 87            with self.db.wrap_database_errors:
 88                if params is None:
 89                    return self.cursor.execute(sql)
 90                else:
 91                    return self.cursor.execute(sql, params)
 92
 93    def _executemany(self, sql, param_list, *ignored_wrapper_args):
 94        with db_span(self.db, sql, many=True, params=param_list):
 95            self.db.validate_no_broken_transaction()
 96            with self.db.wrap_database_errors:
 97                return self.cursor.executemany(sql, param_list)
 98
 99
class CursorDebugWrapper(CursorWrapper):
    """
    CursorWrapper that times and records execute()/executemany() calls.

    After each call, an entry holding the SQL and the elapsed seconds is
    appended to ``db.queries_log`` and the statement is emitted through
    ``logger.debug``.
    """

    # XXX callproc isn't instrumented at this time.

    def execute(self, sql, params=None):
        recorder = self.debug_sql(sql, params, use_last_executed_query=True)
        with recorder:
            return super().execute(sql, params)

    def executemany(self, sql, param_list):
        recorder = self.debug_sql(sql, param_list, many=True)
        with recorder:
            return super().executemany(sql, param_list)

    @contextmanager
    def debug_sql(
        self, sql=None, params=None, use_last_executed_query=False, many=False
    ):
        """
        Time the enclosed block, then record and log ``sql``, even if the
        block raised.

        With ``use_last_executed_query`` the SQL is replaced by
        ``db.ops.last_executed_query(cursor, sql, params)``. With ``many`` the
        queries_log entry is prefixed by the number of parameter sets, or "?"
        when ``params`` has no length.
        """
        began = time.monotonic()
        try:
            yield
        finally:
            elapsed = time.monotonic() - began
            if use_last_executed_query:
                sql = self.db.ops.last_executed_query(self.cursor, sql, params)
            if many:
                try:
                    times = len(params)
                except TypeError:
                    # params could be an iterator.
                    times = "?"
                recorded_sql = f"{times} times: {sql}"
            else:
                recorded_sql = sql
            self.db.queries_log.append(
                {"sql": recorded_sql, "time": f"{elapsed:.3f}"}
            )
            logger.debug(
                "(%.3f) %s; args=%s",
                elapsed,
                sql,
                params,
                extra={"duration": elapsed, "sql": sql, "params": params},
            )
145
146
@contextmanager
def debug_transaction(connection, sql):
    """
    Time the enclosed block and, when ``connection.queries_logged`` is true,
    record ``sql`` with its duration.

    The record is appended to ``connection.queries_log`` and emitted through
    ``logger.debug``. This happens in a ``finally`` clause, so it runs even if
    the block raised.
    """
    began = time.monotonic()
    try:
        yield
    finally:
        if connection.queries_logged:
            elapsed = time.monotonic() - began
            connection.queries_log.append(
                {"sql": f"{sql}", "time": f"{elapsed:.3f}"}
            )
            logger.debug(
                "(%.3f) %s; args=%s",
                elapsed,
                sql,
                None,
                extra={"duration": elapsed, "sql": sql},
            )
172
173
def split_tzname_delta(tzname):
    """
    Split a time zone name into a 3-tuple of (name, sign, offset).

    "+" is tried before "-". For each sign present, the text after its last
    occurrence is accepted as the offset only if it is non-empty and
    ``parse_time`` returns a truthy value for it. When no sign yields an
    offset, ``(tzname, None, None)`` is returned.
    """
    for sign in ("+", "-"):
        if sign not in tzname:
            continue
        name, _, offset = tzname.rpartition(sign)
        if offset and parse_time(offset):
            return name, sign, offset
    return tzname, None, None
184
185
186###############################################
187# Converters from database (string) to Python #
188###############################################
189
190
def typecast_date(s):
    """
    Convert a "YYYY-MM-DD" string into a ``datetime.date``.

    Falsy input (None or "") yields None.
    """
    if not s:
        return None  # return None if s is null
    parts = [int(piece) for piece in s.split("-")]
    return datetime.date(*parts)
195
196
def typecast_time(s):  # does NOT store time zone information
    """
    Convert an "HH:MM:SS[.ffffff]" string into a naive ``datetime.time``.

    The fractional part is right-padded with zeros and truncated to six
    digits, so ".5" becomes 500000 microseconds. Falsy input yields None.
    """
    if not s:
        return None
    hour, minutes, seconds = s.split(":")
    microseconds = "0"
    if "." in seconds:  # seconds carry a fractional part
        seconds, microseconds = seconds.split(".")
    return datetime.time(
        int(hour),
        int(minutes),
        int(seconds),
        int(microseconds.ljust(6, "0")[:6]),
    )
208
209
def typecast_timestamp(s):  # does NOT store time zone information
    """
    Convert a "YYYY-MM-DD HH:MM:SS[.ffffff][+-offset]" string into a naive
    ``datetime.datetime``.

    Any trailing UTC offset is discarded. A string with no space is treated
    as a bare date and handed to ``typecast_date``. Falsy input yields None.
    """
    # "2005-07-29 15:48:00.590358-05"
    # "2005-07-29 09:56:00-05"
    if not s:
        return None
    if " " not in s:
        return typecast_date(s)
    date_part, time_part = s.split()
    # Remove timezone information; "-" is checked before "+".
    for marker in ("-", "+"):
        if marker in time_part:
            time_part = time_part.split(marker, 1)[0]
            break
    ymd = date_part.split("-")
    hms = time_part.split(":")
    seconds = hms[2]
    microseconds = "0"
    if "." in seconds:  # seconds carry a fractional part
        seconds, microseconds = seconds.split(".")
    return datetime.datetime(
        int(ymd[0]),
        int(ymd[1]),
        int(ymd[2]),
        int(hms[0]),
        int(hms[1]),
        int(seconds),
        int(microseconds.ljust(6, "0")[:6]),
    )
239
240
241###############################################
242# Converters from Python to database (string) #
243###############################################
244
245
def split_identifier(identifier):
    """
    Split an SQL identifier into a two element tuple of (namespace, name).

    The identifier could be a table, column, or sequence name and might be
    prefixed by a namespace, as in '"USER"."TABLE"'. Only an identifier with
    exactly one '"."' separator is split; otherwise the namespace is "".
    Surrounding double quotes are stripped from both parts.
    """
    parts = identifier.split('"."')
    if len(parts) == 2:
        namespace, name = parts
    else:
        namespace, name = "", identifier
    return namespace.strip('"'), name.strip('"')
258
259
def truncate_name(identifier, length=None, hash_len=4):
    """
    Shorten an SQL identifier to a repeatable mangled version with the given
    length.

    If a quote stripped name contains a namespace, e.g. USERNAME"."TABLE,
    truncate the table portion only. The identifier is returned unchanged
    when ``length`` is None or the name already fits. Otherwise the name is
    cut to ``length - hash_len`` characters and a ``hash_len``-character
    digest of the full name is appended.
    """
    namespace, name = split_identifier(identifier)
    if length is None or len(name) <= length:
        return identifier
    suffix = names_digest(name, length=hash_len)
    prefix = f'{namespace}"."' if namespace else ""
    return prefix + name[: length - hash_len] + suffix
279
280
def names_digest(*args, length):
    """
    Generate a short hex digest of a set of arguments that can be used to
    shorten identifying names.

    The string arguments are encoded (UTF-8, ``str.encode``'s default) and
    hashed together in order with MD5. The first ``length`` hex characters
    are returned; each hex character carries 4 bits, so ``length=8`` yields a
    32-bit digest. MD5 is requested with ``usedforsecurity=False`` because it
    only serves as a name hash here.
    """
    payload = b"".join(arg.encode() for arg in args)
    return md5(payload, usedforsecurity=False).hexdigest()[:length]
290
291
def format_number(value, max_digits, decimal_places):
    """
    Format a number into a string with the requisite number of digits and
    decimal places.

    ``value`` is expected to be a ``decimal.Decimal`` (or None, which is
    returned as-is). ``max_digits`` becomes the precision of a copy of the
    current decimal context.
    - With ``decimal_places``, the value is quantized to that many places.
    - Without ``decimal_places``, the value is re-created under that context.
      The ``Rounded`` signal is trapped, so a value needing more digits than
      the precision raises ``decimal.Rounded`` instead of being silently
      rounded.
    The result is rendered in fixed-point notation (no exponent).
    """
    if value is None:
        return None
    ctx = decimal.getcontext().copy()
    if max_digits is not None:
        ctx.prec = max_digits
    if decimal_places is None:
        ctx.traps[decimal.Rounded] = 1
        value = ctx.create_decimal(value)
    else:
        quantum = decimal.Decimal(1).scaleb(-decimal_places)
        value = value.quantize(quantum, context=ctx)
    return format(value, "f")
310
311
def strip_quotes(table_name):
    """
    Strip quotes off of quoted table names to make them safe for use in index
    names, sequence names, etc. For example '"USER"."TABLE"' (an Oracle naming
    scheme) becomes 'USER"."TABLE'.

    Exactly one character is removed from each end, and only when the name
    both starts and ends with a double quote; otherwise it is returned as-is.
    """
    if not (table_name.startswith('"') and table_name.endswith('"')):
        return table_name
    return table_name[1:-1]
318    has_quotes = table_name.startswith('"') and table_name.endswith('"')
319    return table_name[1:-1] if has_quotes else table_name