1import datetime
2import decimal
3import functools
4import logging
5import time
6from contextlib import contextmanager
7from hashlib import md5
8
9from plain.models.db import NotSupportedError
10from plain.utils.dateparse import parse_time
11
# Module logger: SQL timed by CursorDebugWrapper.debug_sql and by
# debug_transaction is emitted here at DEBUG level.
logger = logging.getLogger("plain.models.backends")
13
14
class CursorWrapper:
    """
    Proxy around a DB-API cursor owned by connection ``db``.

    Attribute access falls through to the wrapped cursor; the fetch/nextset
    methods listed in WRAP_ERROR_ATTRS are wrapped with
    ``db.wrap_database_errors``. ``execute``/``executemany`` check that the
    transaction is not broken and run through ``db.execute_wrappers``.
    """

    WRAP_ERROR_ATTRS = frozenset(["fetchone", "fetchmany", "fetchall", "nextset"])

    def __init__(self, cursor, db):
        self.cursor = cursor
        self.db = db

    def __getattr__(self, attr):
        # Only consulted for names not found on the wrapper itself.
        target = getattr(self.cursor, attr)
        if attr not in CursorWrapper.WRAP_ERROR_ATTRS:
            return target
        return self.db.wrap_database_errors(target)

    def __iter__(self):
        with self.db.wrap_database_errors:
            yield from self.cursor

    def __enter__(self):
        return self

    def __exit__(self, type, value, traceback):
        # Always close explicitly rather than deferring to the driver's own
        # context-manager behavior, which differs between backends (#17671).
        # Driver errors raised while closing are deliberately ignored.
        try:
            self.close()
        except self.db.Database.Error:
            pass

    # The methods below can't be served by __getattr__: their logic has to run
    # when they are *called*, not merely when they are looked up.

    def callproc(self, procname, params=None, kparams=None):
        # PEP 249 has no keyword parameters for callproc, but some drivers
        # (e.g. cx_Oracle) accept them; only pass them on if the backend says so.
        if kparams is not None and not self.db.features.supports_callproc_kwargs:
            raise NotSupportedError(
                "Keyword parameters for callproc are not supported on this "
                "database backend."
            )
        self.db.validate_no_broken_transaction()
        with self.db.wrap_database_errors:
            if kparams is not None:
                return self.cursor.callproc(procname, params or (), kparams)
            if params is not None:
                return self.cursor.callproc(procname, params)
            return self.cursor.callproc(procname)

    def execute(self, sql, params=None):
        return self._execute_with_wrappers(
            sql, params, many=False, executor=self._execute
        )

    def executemany(self, sql, param_list):
        return self._execute_with_wrappers(
            sql, param_list, many=True, executor=self._executemany
        )

    def _execute_with_wrappers(self, sql, params, many, executor):
        context = {"connection": self.db, "cursor": self}
        # Wrap inside-out so the first registered wrapper ends up outermost.
        call = executor
        for wrapper in reversed(self.db.execute_wrappers):
            call = functools.partial(wrapper, call)
        return call(sql, params, many, context)

    def _execute(self, sql, params, *ignored_wrapper_args):
        self.db.validate_no_broken_transaction()
        with self.db.wrap_database_errors:
            if params is not None:
                return self.cursor.execute(sql, params)
            # Omit the params argument entirely: its default may be
            # backend-specific.
            return self.cursor.execute(sql)

    def _executemany(self, sql, param_list, *ignored_wrapper_args):
        self.db.validate_no_broken_transaction()
        with self.db.wrap_database_errors:
            return self.cursor.executemany(sql, param_list)
95
96
class CursorDebugWrapper(CursorWrapper):
    """
    CursorWrapper that times each ``execute``/``executemany`` call. It appends
    the statement and its duration to ``db.queries_log`` and logs both at
    DEBUG level.
    """

    # XXX callproc isn't instrumented at this time.

    def execute(self, sql, params=None):
        with self.debug_sql(sql, params, use_last_executed_query=True):
            return super().execute(sql, params)

    def executemany(self, sql, param_list):
        with self.debug_sql(sql, param_list, many=True):
            return super().executemany(sql, param_list)

    @contextmanager
    def debug_sql(
        self, sql=None, params=None, use_last_executed_query=False, many=False
    ):
        """
        Time the enclosed block and record it, even if the block raises.

        With ``use_last_executed_query`` the recorded SQL is whatever
        ``db.ops.last_executed_query`` reports for the cursor. With ``many``
        the queries_log entry is prefixed with "<count> times: ", where the
        count is ``len(params)`` or "?" when params has no length.
        """
        started = time.monotonic()
        try:
            yield
        finally:
            elapsed = time.monotonic() - started
            if use_last_executed_query:
                sql = self.db.ops.last_executed_query(self.cursor, sql, params)
            if many:
                try:
                    count = len(params)
                except TypeError:
                    # params could be an iterator.
                    count = "?"
                logged_sql = f"{count} times: {sql}"
            else:
                logged_sql = sql
            self.db.queries_log.append({"sql": logged_sql, "time": f"{elapsed:.3f}"})
            logger.debug(
                "(%.3f) %s; args=%s",
                elapsed,
                sql,
                params,
                extra={"duration": elapsed, "sql": sql, "params": params},
            )
142
143
@contextmanager
def debug_transaction(connection, sql):
    """
    Time the enclosed block and, if ``connection.queries_logged`` is truthy,
    record ``sql`` and the elapsed time.

    The entry is appended to ``connection.queries_log`` and logged at DEBUG
    level with ``args=None``. Recording happens in a ``finally`` clause, so it
    also runs when the block raises; the exception still propagates.
    """
    began = time.monotonic()
    try:
        yield
    finally:
        if connection.queries_logged:
            elapsed = time.monotonic() - began
            connection.queries_log.append({"sql": f"{sql}", "time": f"{elapsed:.3f}"})
            logger.debug(
                "(%.3f) %s; args=%s",
                elapsed,
                sql,
                None,
                extra={"duration": elapsed, "sql": sql},
            )
169
170
def split_tzname_delta(tzname):
    """
    Split a time zone name into a 3-tuple of ``(name, sign, offset)``.

    "+" is tried before "-". The name is split at the last occurrence of the
    sign, and the split is accepted only when the text after it is non-empty
    and ``parse_time`` accepts it. Otherwise ``(tzname, None, None)`` is
    returned.
    """
    for sign in ("+", "-"):
        if sign not in tzname:
            continue
        name, _, offset = tzname.rpartition(sign)
        if offset and parse_time(offset):
            return name, sign, offset
    return tzname, None, None
181
182
183###############################################
184# Converters from database (string) to Python #
185###############################################
186
187
def typecast_date(s):
    """Convert a "YYYY-MM-DD" string to a ``datetime.date``; falsy input (NULL) gives None."""
    if not s:
        return None
    parts = [int(piece) for piece in s.split("-")]
    return datetime.date(*parts)
192
193
def typecast_time(s):  # does NOT store time zone information
    """
    Convert an "HH:MM:SS[.fraction]" string to a naive ``datetime.time``.

    The fraction is right-padded with zeros and cut to 6 digits, so anything
    finer than a microsecond is dropped. Falsy input gives None.
    """
    if not s:
        return None
    hour, minutes, seconds = s.split(":")
    microseconds = "0"
    if "." in seconds:  # the seconds field carries a fractional part
        seconds, microseconds = seconds.split(".")
    return datetime.time(
        int(hour),
        int(minutes),
        int(seconds),
        int(microseconds.ljust(6, "0")[:6]),
    )
205
206
def typecast_timestamp(s):  # does NOT store time zone information
    """
    Convert "YYYY-MM-DD HH:MM:SS[.fraction][+/-offset]" to a naive datetime.

    Examples: "2005-07-29 15:48:00.590358-05", "2005-07-29 09:56:00-05".
    Any UTC offset is discarded. A value with no space is treated as a bare
    date and returned as whatever ``typecast_date`` produces. Falsy input
    gives None.
    """
    if not s:
        return None
    if " " not in s:
        return typecast_date(s)
    date_part, time_part = s.split()
    # Drop the time zone suffix; "-" is checked before "+".
    for marker in ("-", "+"):
        if marker in time_part:
            time_part = time_part.split(marker, 1)[0]
            break
    ymd = date_part.split("-")
    hms = time_part.split(":")
    seconds, microseconds = hms[2], "0"
    if "." in seconds:  # the seconds field carries a fractional part
        seconds, microseconds = seconds.split(".")
    return datetime.datetime(
        int(ymd[0]),
        int(ymd[1]),
        int(ymd[2]),
        int(hms[0]),
        int(hms[1]),
        int(seconds),
        int(microseconds.ljust(6, "0")[:6]),
    )
236
237
238###############################################
239# Converters from Python to database (string) #
240###############################################
241
242
def split_identifier(identifier):
    """
    Split an SQL identifier into a two element tuple of (namespace, name).

    The identifier (a table, column, or sequence name) may be prefixed by a
    namespace, as in '"USER"."TABLE"'. Only exactly one '"."' separator is
    treated as a namespace split; otherwise the namespace is "". Surrounding
    double quotes are stripped from both parts.
    """
    parts = identifier.split('"."')
    if len(parts) == 2:
        namespace, name = parts
    else:
        namespace, name = "", identifier
    return namespace.strip('"'), name.strip('"')
255
256
def truncate_name(identifier, length=None, hash_len=4):
    """
    Shorten an SQL identifier to a repeatable mangled version with the given
    length.

    If a quote stripped name contains a namespace, e.g. USERNAME"."TABLE,
    truncate the table portion only; ``length`` applies to that portion.

    ``identifier`` is returned unchanged when ``length`` is None or the name
    already fits. Otherwise the name is cut down and suffixed with a digest of
    the full name (see names_digest), ``hash_len`` characters long. That keeps
    distinct long names distinct after truncation. The digest never exceeds
    ``length``, so the name portion of the result is exactly ``length``
    characters.
    """
    namespace, name = split_identifier(identifier)

    if length is None or len(name) <= length:
        return identifier

    # Without this clamp, ``length - hash_len`` goes negative when
    # hash_len > length. The slice below would then keep almost all of the
    # name, producing a result longer than both ``length`` and the input.
    hash_len = min(hash_len, length)
    digest = names_digest(name, length=hash_len)
    prefix = f'{namespace}"."' if namespace else ""
    return f"{prefix}{name[: length - hash_len]}{digest}"
276
277
def names_digest(*args, length):
    """
    Generate a short, repeatable hex digest of a set of strings that can be
    used to shorten identifying names.

    The strings are encoded (UTF-8) and fed to MD5 in order, so the digest
    depends only on their concatenation: ("ab", "c") and ("a", "bc") collide.
    The hex digest is truncated to ``length`` characters, i.e. ``4 * length``
    bits (32 bits for length=8). MD5 is used purely as a non-cryptographic
    fingerprint, hence ``usedforsecurity=False``.
    """
    h = md5(usedforsecurity=False)
    for arg in args:
        h.update(arg.encode())
    return h.hexdigest()[:length]
287
288
def format_number(value, max_digits, decimal_places):
    """
    Format a number into a string with the requisite number of digits and
    decimal places.

    None is passed through. The work uses a copy of the current decimal
    context, with its precision set to ``max_digits`` when given. If
    ``decimal_places`` is None, the value is rebuilt under that context with
    the ``decimal.Rounded`` trap enabled, so a value needing rounding raises.
    Otherwise it is quantized to ``decimal_places`` fractional digits. The
    result is rendered in fixed-point notation (no exponent).
    """
    if value is None:
        return None
    ctx = decimal.getcontext().copy()
    if max_digits is not None:
        ctx.prec = max_digits
    if decimal_places is None:
        ctx.traps[decimal.Rounded] = 1
        value = ctx.create_decimal(value)
    else:
        value = value.quantize(decimal.Decimal(1).scaleb(-decimal_places), context=ctx)
    return format(value, "f")
307
308
def strip_quotes(table_name):
    """
    Remove one pair of enclosing double quotes from a quoted table name.

    This makes the name safe to embed in index names, sequence names, etc.
    For example '"USER"."TABLE"' (an Oracle naming scheme) becomes
    'USER"."TABLE'. Names that don't both start and end with '"' are returned
    unchanged.
    """
    if table_name.startswith('"') and table_name.endswith('"'):
        return table_name[1:-1]
    return table_name