Plain is headed towards 1.0! Subscribe for development updates →

  1import datetime
  2import decimal
  3import functools
  4import logging
  5import time
  6from contextlib import contextmanager
  7from hashlib import md5
  8
  9from plain.models.db import NotSupportedError
 10from plain.utils.dateparse import parse_time
 11
# Module logger that receives the DEBUG records for executed SQL statements
# and transaction-control statements produced by the helpers below.
logger = logging.getLogger("plain.models.backends")
 13
 14
 15class CursorWrapper:
 16    def __init__(self, cursor, db):
 17        self.cursor = cursor
 18        self.db = db
 19
 20    WRAP_ERROR_ATTRS = frozenset(["fetchone", "fetchmany", "fetchall", "nextset"])
 21
 22    def __getattr__(self, attr):
 23        cursor_attr = getattr(self.cursor, attr)
 24        if attr in CursorWrapper.WRAP_ERROR_ATTRS:
 25            return self.db.wrap_database_errors(cursor_attr)
 26        else:
 27            return cursor_attr
 28
 29    def __iter__(self):
 30        with self.db.wrap_database_errors:
 31            yield from self.cursor
 32
 33    def __enter__(self):
 34        return self
 35
 36    def __exit__(self, type, value, traceback):
 37        # Close instead of passing through to avoid backend-specific behavior
 38        # (#17671). Catch errors liberally because errors in cleanup code
 39        # aren't useful.
 40        try:
 41            self.close()
 42        except self.db.Database.Error:
 43            pass
 44
 45    # The following methods cannot be implemented in __getattr__, because the
 46    # code must run when the method is invoked, not just when it is accessed.
 47
 48    def callproc(self, procname, params=None, kparams=None):
 49        # Keyword parameters for callproc aren't supported in PEP 249, but the
 50        # database driver may support them (e.g. cx_Oracle).
 51        if kparams is not None and not self.db.features.supports_callproc_kwargs:
 52            raise NotSupportedError(
 53                "Keyword parameters for callproc are not supported on this "
 54                "database backend."
 55            )
 56        self.db.validate_no_broken_transaction()
 57        with self.db.wrap_database_errors:
 58            if params is None and kparams is None:
 59                return self.cursor.callproc(procname)
 60            elif kparams is None:
 61                return self.cursor.callproc(procname, params)
 62            else:
 63                params = params or ()
 64                return self.cursor.callproc(procname, params, kparams)
 65
 66    def execute(self, sql, params=None):
 67        return self._execute_with_wrappers(
 68            sql, params, many=False, executor=self._execute
 69        )
 70
 71    def executemany(self, sql, param_list):
 72        return self._execute_with_wrappers(
 73            sql, param_list, many=True, executor=self._executemany
 74        )
 75
 76    def _execute_with_wrappers(self, sql, params, many, executor):
 77        context = {"connection": self.db, "cursor": self}
 78        for wrapper in reversed(self.db.execute_wrappers):
 79            executor = functools.partial(wrapper, executor)
 80        return executor(sql, params, many, context)
 81
 82    def _execute(self, sql, params, *ignored_wrapper_args):
 83        self.db.validate_no_broken_transaction()
 84        with self.db.wrap_database_errors:
 85            if params is None:
 86                # params default might be backend specific.
 87                return self.cursor.execute(sql)
 88            else:
 89                return self.cursor.execute(sql, params)
 90
 91    def _executemany(self, sql, param_list, *ignored_wrapper_args):
 92        self.db.validate_no_broken_transaction()
 93        with self.db.wrap_database_errors:
 94            return self.cursor.executemany(sql, param_list)
 95
 96
 97class CursorDebugWrapper(CursorWrapper):
 98    # XXX callproc isn't instrumented at this time.
 99
100    def execute(self, sql, params=None):
101        with self.debug_sql(sql, params, use_last_executed_query=True):
102            return super().execute(sql, params)
103
104    def executemany(self, sql, param_list):
105        with self.debug_sql(sql, param_list, many=True):
106            return super().executemany(sql, param_list)
107
108    @contextmanager
109    def debug_sql(
110        self, sql=None, params=None, use_last_executed_query=False, many=False
111    ):
112        start = time.monotonic()
113        try:
114            yield
115        finally:
116            stop = time.monotonic()
117            duration = stop - start
118            if use_last_executed_query:
119                sql = self.db.ops.last_executed_query(self.cursor, sql, params)
120            try:
121                times = len(params) if many else ""
122            except TypeError:
123                # params could be an iterator.
124                times = "?"
125            self.db.queries_log.append(
126                {
127                    "sql": f"{times} times: {sql}" if many else sql,
128                    "time": f"{duration:.3f}",
129                }
130            )
131            logger.debug(
132                "(%.3f) %s; args=%s",
133                duration,
134                sql,
135                params,
136                extra={
137                    "duration": duration,
138                    "sql": sql,
139                    "params": params,
140                },
141            )
142
143
@contextmanager
def debug_transaction(connection, sql):
    """
    Time a transaction-control statement and record it on *connection*.

    The timing is only recorded (to ``connection.queries_log`` and as a DEBUG
    record on the module logger) when ``connection.queries_logged`` is truthy
    at the moment the wrapped block exits. Recording happens in ``finally``,
    so it also covers blocks that raise; the exception still propagates.
    """
    began = time.monotonic()
    try:
        yield
    finally:
        # No early return here: returning from ``finally`` would swallow any
        # exception raised inside the with-block.
        if connection.queries_logged:
            elapsed = time.monotonic() - began
            entry = {"sql": f"{sql}", "time": f"{elapsed:.3f}"}
            connection.queries_log.append(entry)
            logger.debug(
                "(%.3f) %s; args=%s",
                elapsed,
                sql,
                None,
                extra={"duration": elapsed, "sql": sql},
            )
169
170
def split_tzname_delta(tzname):
    """
    Split a time zone name into a 3-tuple of (name, sign, offset).

    Signs are tried in the order "+" then "-". For each sign present, the
    text after its last occurrence is taken as the offset if it is non-empty
    and parse_time() accepts it. Otherwise ``(tzname, None, None)`` is
    returned.
    """
    for sign in ("+", "-"):
        if sign not in tzname:
            continue
        name, _, offset = tzname.rpartition(sign)
        if offset and parse_time(offset):
            return name, sign, offset
    return tzname, None, None
181
182
183###############################################
184# Converters from database (string) to Python #
185###############################################
186
187
def typecast_date(s):
    """
    Convert a ``"YYYY-MM-DD"`` string into a :class:`datetime.date`.

    Return None when *s* is empty or None (a null database value).
    """
    if not s:
        return None
    return datetime.date(*[int(part) for part in s.split("-")])
192
193
def typecast_time(s):  # does NOT store time zone information
    """
    Convert an ``"HH:MM:SS[.fraction]"`` string into a naive
    :class:`datetime.time`.

    Return None when *s* is empty or None. The fractional part is
    right-padded with zeros and cut to six digits (microseconds).
    """
    if not s:
        return None
    hour, minutes, seconds = s.split(":")
    fraction = "0"
    if "." in seconds:  # seconds carry a fractional part
        seconds, fraction = seconds.split(".")
    microseconds = int(fraction.ljust(6, "0")[:6])
    return datetime.time(int(hour), int(minutes), int(seconds), microseconds)
205
206
def typecast_timestamp(s):  # does NOT store time zone information
    """
    Convert a ``"YYYY-MM-DD HH:MM:SS[.fraction][+/-TZ]"`` string into a naive
    :class:`datetime.datetime`.

    A trailing UTC offset is discarded, not applied. Return None when *s* is
    empty or None. When *s* has no space-separated time part, delegate to
    typecast_date(), which returns a :class:`datetime.date`.
    """
    # "2005-07-29 15:48:00.590358-05"
    # "2005-07-29 09:56:00-05"
    if not s:
        return None
    if " " not in s:
        return typecast_date(s)
    date_part, time_part = s.split()
    # Remove timezone information: cut at the first "-" if there is one,
    # otherwise at the first "+".
    for tz_sign in ("-", "+"):
        if tz_sign in time_part:
            time_part = time_part.split(tz_sign, 1)[0]
            break
    ymd = date_part.split("-")
    hms = time_part.split(":")
    seconds = hms[2]
    fraction = "0"
    if "." in seconds:  # seconds carry a fractional part
        seconds, fraction = seconds.split(".")
    return datetime.datetime(
        int(ymd[0]),
        int(ymd[1]),
        int(ymd[2]),
        int(hms[0]),
        int(hms[1]),
        int(seconds),
        int(fraction.ljust(6, "0")[:6]),
    )
236
237
238###############################################
239# Converters from Python to database (string) #
240###############################################
241
242
def split_identifier(identifier):
    """
    Split an SQL identifier into a two element tuple of (namespace, name).

    The identifier may be a table, column, or sequence name, optionally
    prefixed by a namespace as in '"USER"."TABLE"'. Surrounding double quotes
    are stripped from both parts. Unless the identifier contains exactly one
    '"."' separator, the namespace is the empty string.
    """
    pieces = identifier.split('"."')
    if len(pieces) == 2:
        namespace, name = pieces
    else:
        namespace, name = "", identifier
    return namespace.strip('"'), name.strip('"')
255
256
def truncate_name(identifier, length=None, hash_len=4):
    """
    Shorten an SQL identifier to a repeatable mangled version with the given
    length.

    If a quote stripped name contains a namespace, e.g. USERNAME"."TABLE,
    truncate the table portion only.

    :param identifier: SQL identifier, optionally namespace-qualified.
    :param length: Maximum length of the unqualified name; None disables
        truncation.
    :param hash_len: Number of digest characters appended to a truncated name
        so that different long names stay distinct. Clamped to *length*.
    :returns: *identifier* unchanged when no truncation is needed. Otherwise
        the unquoted namespace (if any) joined with '"."', followed by the
        shortened name plus a digest of the full name.
    """
    namespace, name = split_identifier(identifier)

    if length is None or len(name) <= length:
        return identifier

    # The digest is part of the result, so it may not be longer than the
    # allowed length. Without this clamp, a negative ``length - hash_len``
    # would slice from the end of the name and the result would exceed
    # *length*.
    hash_len = min(hash_len, length)
    digest = names_digest(name, length=hash_len)
    prefix = f'{namespace}"."' if namespace else ""
    return f"{prefix}{name[: length - hash_len]}{digest}"
276
277
def names_digest(*args, length):
    """
    Return the first *length* hex characters of an MD5 digest of *args*.

    Used to shorten identifying names; MD5 is requested with
    ``usedforsecurity=False`` since this is not a security use. The string
    arguments are encoded with the default encoding and concatenated before
    hashing. As a result, ("ab", "c") and ("a", "bc") yield the same digest.
    That is kept as-is because changing it would change every generated name.
    """
    payload = b"".join(arg.encode() for arg in args)
    return md5(payload, usedforsecurity=False).hexdigest()[:length]
287
288
def format_number(value, max_digits, decimal_places):
    """
    Format a number into a string with the requisite number of digits and
    decimal places.

    Return None for a None *value*. The arithmetic uses a copy of the
    current decimal context whose precision is *max_digits* when given.
    With *decimal_places*, *value* is quantized to that many places.
    Without it, any rounding needed to fit the precision raises
    ``decimal.Rounded`` instead of being applied silently.
    """
    if value is None:
        return None
    ctx = decimal.getcontext().copy()
    if max_digits is not None:
        ctx.prec = max_digits
    if decimal_places is None:
        ctx.traps[decimal.Rounded] = 1
        value = ctx.create_decimal(value)
    else:
        exponent = decimal.Decimal(1).scaleb(-decimal_places)
        value = value.quantize(exponent, context=ctx)
    return format(value, "f")
307
308
def strip_quotes(table_name):
    """
    Strip quotes off of quoted table names to make them safe for use in index
    names, sequence names, etc. For example '"USER"."TABLE"' (an Oracle naming
    scheme) becomes 'USER"."TABLE'.

    Only one leading and one trailing double quote are removed, and only
    when both are present.
    """
    if table_name.startswith('"') and table_name.endswith('"'):
        return table_name[1:-1]
    return table_name