1import datetime
2import decimal
3import functools
4import logging
5import time
6from contextlib import contextmanager
7from hashlib import md5
8
9from plain.models.db import NotSupportedError
10from plain.utils.dateparse import parse_time
11
12logger = logging.getLogger("plain.models.backends")
13
14
class CursorWrapper:
    """
    Proxy around a DB-API cursor that routes calls through its connection.

    Unknown attributes are looked up on the wrapped cursor. The fetch/nextset
    methods named in ``WRAP_ERROR_ATTRS`` are returned wrapped by
    ``db.wrap_database_errors``. ``execute``, ``executemany`` and ``callproc``
    first call ``db.validate_no_broken_transaction()`` and then run inside
    ``db.wrap_database_errors``. ``execute``/``executemany`` additionally pass
    through every callable in ``db.execute_wrappers``.
    """

    # Cursor methods whose callables are wrapped by ``db.wrap_database_errors``
    # when fetched through ``__getattr__``.
    WRAP_ERROR_ATTRS = frozenset(["fetchone", "fetchmany", "fetchall", "nextset"])

    def __init__(self, cursor, db):
        self.cursor = cursor
        self.db = db

    def __getattr__(self, attr):
        # Only consulted for names the wrapper itself does not define.
        target = getattr(self.cursor, attr)
        if attr not in CursorWrapper.WRAP_ERROR_ATTRS:
            return target
        return self.db.wrap_database_errors(target)

    def __iter__(self):
        with self.db.wrap_database_errors:
            yield from self.cursor

    def __enter__(self):
        return self

    def __exit__(self, type, value, traceback):
        # Close instead of passing through to avoid backend-specific behavior
        # (#17671). A ``db.Database.Error`` raised while cleaning up is not
        # useful to the caller, so it is ignored.
        try:
            self.close()
        except self.db.Database.Error:
            pass

    # The methods below cannot be served by __getattr__: their extra logic has
    # to run when the method is called, not merely when it is looked up.

    def callproc(self, procname, params=None, kparams=None):
        """
        Call stored procedure ``procname`` on the wrapped cursor.

        Only the arguments that were actually supplied are forwarded. When
        ``kparams`` is given, ``params`` defaults to an empty tuple.

        Raise NotSupportedError if ``kparams`` is given but the backend's
        ``features.supports_callproc_kwargs`` is false.
        """
        # PEP 249 defines no keyword parameters for callproc, but the database
        # driver may accept them (e.g. cx_Oracle).
        if kparams is not None and not self.db.features.supports_callproc_kwargs:
            raise NotSupportedError(
                "Keyword parameters for callproc are not supported on this "
                "database backend."
            )
        self.db.validate_no_broken_transaction()
        with self.db.wrap_database_errors:
            if kparams is not None:
                call_args = (procname, params or (), kparams)
            elif params is not None:
                call_args = (procname, params)
            else:
                call_args = (procname,)
            return self.cursor.callproc(*call_args)

    def execute(self, sql, params=None):
        """Run ``sql`` once, passing through ``db.execute_wrappers``."""
        return self._execute_with_wrappers(
            sql, params, many=False, executor=self._execute
        )

    def executemany(self, sql, param_list):
        """Run ``sql`` for each params entry, passing through the wrappers."""
        return self._execute_with_wrappers(
            sql, param_list, many=True, executor=self._executemany
        )

    def _execute_with_wrappers(self, sql, params, many, executor):
        # Nest the wrappers so execute_wrappers[0] is called first (outermost)
        # and ``executor`` last. Each wrapper receives the next callable
        # followed by (sql, params, many, context).
        context = {"connection": self.db, "cursor": self}
        call = functools.reduce(
            lambda inner, wrapper: functools.partial(wrapper, inner),
            reversed(self.db.execute_wrappers),
            executor,
        )
        return call(sql, params, many, context)

    def _execute(self, sql, params, *ignored_wrapper_args):
        self.db.validate_no_broken_transaction()
        with self.db.wrap_database_errors:
            # Leave params out entirely when None; the driver's default for
            # it might be backend specific.
            call_args = (sql,) if params is None else (sql, params)
            return self.cursor.execute(*call_args)

    def _executemany(self, sql, param_list, *ignored_wrapper_args):
        self.db.validate_no_broken_transaction()
        with self.db.wrap_database_errors:
            return self.cursor.executemany(sql, param_list)
95
96
class CursorDebugWrapper(CursorWrapper):
    """
    CursorWrapper that times ``execute`` and ``executemany``.

    Every call appends a ``{"sql": ..., "time": ...}`` entry to
    ``db.queries_log`` and emits a DEBUG record on the module logger. This
    happens even when the statement raises.
    """

    # XXX callproc isn't instrumented at this time.

    def execute(self, sql, params=None):
        with self.debug_sql(sql=sql, params=params, use_last_executed_query=True):
            return super().execute(sql, params)

    def executemany(self, sql, param_list):
        with self.debug_sql(sql=sql, params=param_list, many=True):
            return super().executemany(sql, param_list)

    @contextmanager
    def debug_sql(
        self, sql=None, params=None, use_last_executed_query=False, many=False
    ):
        """
        Time the enclosed block, then record it in ``db.queries_log`` and log it.

        When ``use_last_executed_query`` is true, the SQL recorded is whatever
        ``db.ops.last_executed_query`` returns. When ``many`` is true, the
        entry is prefixed with ``len(params)`` (or ``"?"`` if ``params`` has
        no length).
        """
        began = time.monotonic()
        try:
            yield
        finally:
            elapsed = time.monotonic() - began
            if use_last_executed_query:
                sql = self.db.ops.last_executed_query(self.cursor, sql, params)
            if many:
                try:
                    repeat = len(params)
                except TypeError:
                    # params could be an iterator.
                    repeat = "?"
                logged_sql = f"{repeat} times: {sql}"
            else:
                logged_sql = sql
            self.db.queries_log.append({"sql": logged_sql, "time": "%.3f" % elapsed})
            logger.debug(
                "(%.3f) %s; args=%s; alias=%s",
                elapsed,
                sql,
                params,
                self.db.alias,
                extra={
                    "duration": elapsed,
                    "sql": sql,
                    "params": params,
                    "alias": self.db.alias,
                },
            )
144
145
@contextmanager
def debug_transaction(connection, sql):
    """
    Time the enclosed block and, if enabled, record and log it as ``sql``.

    Nothing is recorded unless ``connection.queries_logged`` is true. When it
    is, a ``{"sql": ..., "time": ...}`` entry is appended to
    ``connection.queries_log`` and a DEBUG record is emitted. The record is
    emitted even when the enclosed block raises. The statement has no
    parameters, so ``args``/``params`` are always reported as None.
    """
    start = time.monotonic()
    try:
        yield
    finally:
        if connection.queries_logged:
            duration = time.monotonic() - start
            connection.queries_log.append(
                {
                    # str() rather than "%s" % sql: the latter raises
                    # TypeError if sql happens to be a tuple.
                    "sql": str(sql),
                    "time": "%.3f" % duration,
                }
            )
            logger.debug(
                "(%.3f) %s; args=%s; alias=%s",
                duration,
                sql,
                None,
                connection.alias,
                extra={
                    "duration": duration,
                    "sql": sql,
                    # Same extra keys as CursorDebugWrapper.debug_sql, so
                    # handlers can read record.params uniformly.
                    "params": None,
                    "alias": connection.alias,
                },
            )
173
174
def split_tzname_delta(tzname):
    """
    Split a time zone name into a 3-tuple of (name, sign, offset).

    "+" is tried before "-". For each sign, the text after its last occurrence
    is taken as the offset only if it is non-empty and ``parse_time`` returns
    a truthy value for it. If neither sign yields an offset, return
    ``(tzname, None, None)``.
    """
    for sign in ("+", "-"):
        name, found, offset = tzname.rpartition(sign)
        if found and offset and parse_time(offset):
            return name, sign, offset
    return tzname, None, None
185
186
187###############################################
188# Converters from database (string) to Python #
189###############################################
190
191
def typecast_date(s):
    """
    Parse a ``YYYY-MM-DD`` string into a ``datetime.date``.

    Return None for an empty or None value (a NULL column).
    """
    if not s:
        return None
    parts = [int(piece) for piece in s.split("-")]
    return datetime.date(*parts)
196
197
def typecast_time(s):  # does NOT store time zone information
    """
    Parse an ``HH:MM:SS[.ffffff]`` string into a naive ``datetime.time``.

    A fractional part is right-padded with zeros and cut to six digits, so
    ``.5`` means 500000 microseconds. Return None for an empty or None value.
    """
    if not s:
        return None
    hour, minutes, seconds = s.split(":")
    microseconds = "0"
    if "." in seconds:  # seconds carry a fractional part
        seconds, microseconds = seconds.split(".")
    return datetime.time(
        int(hour),
        int(minutes),
        int(seconds),
        int(microseconds.ljust(6, "0")[:6]),
    )
209
210
def typecast_timestamp(s):  # does NOT store time zone information
    """
    Parse a ``YYYY-MM-DD HH:MM:SS[.ffffff][offset]`` string into a naive datetime.

    Examples of accepted input:
        "2005-07-29 15:48:00.590358-05"
        "2005-07-29 09:56:00-05"

    Any UTC offset after the time is dropped. A value without a space is
    handed to ``typecast_date`` and therefore yields a ``date``. Return None
    for an empty or None value.
    """
    if not s:
        return None
    if " " not in s:
        return typecast_date(s)
    date_part, time_part = s.split()
    # Strip the offset: a "-" is looked for first, then a "+".
    for tz_sep in ("-", "+"):
        if tz_sep in time_part:
            time_part = time_part.partition(tz_sep)[0]
            break
    ymd = date_part.split("-")
    hms = time_part.split(":")
    seconds = hms[2]
    microseconds = "0"
    if "." in seconds:  # seconds carry a fractional part
        seconds, microseconds = seconds.split(".")
    return datetime.datetime(
        int(ymd[0]),
        int(ymd[1]),
        int(ymd[2]),
        int(hms[0]),
        int(hms[1]),
        int(seconds),
        int(microseconds.ljust(6, "0")[:6]),
    )
240
241
242###############################################
243# Converters from Python to database (string) #
244###############################################
245
246
def split_identifier(identifier):
    """
    Split an SQL identifier into a two element tuple of (namespace, name).

    The identifier can be a table, column, or sequence name, optionally
    prefixed by a namespace (e.g. ``"USER"."TABLE"``). Surrounding double
    quotes are stripped from both parts. If the identifier does not split into
    exactly two parts on ``"."``, the namespace is empty.
    """
    parts = identifier.split('"."')
    if len(parts) == 2:
        namespace, name = parts
    else:
        namespace, name = "", identifier
    return namespace.strip('"'), name.strip('"')
259
260
def truncate_name(identifier, length=None, hash_len=4):
    """
    Shorten an SQL identifier to a repeatable mangled version with the given
    length.

    If a quote stripped name contains a namespace, e.g. USERNAME"."TABLE,
    truncate the table portion only. When ``length`` is None or the name
    already fits, ``identifier`` is returned unchanged. Otherwise the name is
    cut to ``length - hash_len`` characters and followed by a ``hash_len``
    character digest of the full name.
    """
    namespace, name = split_identifier(identifier)

    if length is not None and len(name) > length:
        prefix = f'{namespace}"."' if namespace else ""
        return prefix + name[: length - hash_len] + names_digest(name, length=hash_len)
    return identifier
280
281
def names_digest(*args, length):
    """
    Generate a short, stable digest of a set of strings for shortening
    identifying names.

    The args are UTF-8 encoded, concatenated in order and hashed with MD5 (not
    used for security). The first ``length`` hex characters are returned, so
    ``length=8`` gives a 32-bit digest.
    """
    payload = b"".join(arg.encode() for arg in args)
    return md5(payload, usedforsecurity=False).hexdigest()[:length]
291
292
def format_number(value, max_digits, decimal_places):
    """
    Format a number into a string with the requisite number of digits and
    decimal places.

    ``value`` is expected to be a ``decimal.Decimal`` (``quantize`` is called
    on it) or None, which is returned as-is. ``max_digits`` becomes the
    context precision.

    With ``decimal_places`` set, the value is quantized to that many places;
    this raises InvalidOperation if the result needs more than ``max_digits``
    digits. Without ``decimal_places``, any rounding raises
    ``decimal.Rounded``.
    """
    if value is None:
        return None
    ctx = decimal.getcontext().copy()
    if max_digits is not None:
        ctx.prec = max_digits
    if decimal_places is None:
        # Trap Rounded so digits are never dropped silently.
        ctx.traps[decimal.Rounded] = 1
        value = ctx.create_decimal(value)
    else:
        quantum = decimal.Decimal(1).scaleb(-decimal_places)
        value = value.quantize(quantum, context=ctx)
    return format(value, "f")
311
312
def strip_quotes(table_name):
    """
    Strip quotes off of quoted table names to make them safe for use in index
    names, sequence names, etc.

    For example '"USER"."TABLE"' (an Oracle naming scheme) becomes
    'USER"."TABLE'. Only one leading and one trailing double quote are
    removed, and only when both are present.
    """
    if not (table_name.startswith('"') and table_name.endswith('"')):
        return table_name
    return table_name[1:-1]