1import datetime
2import decimal
3import functools
4import logging
5import time
6from contextlib import contextmanager
7from hashlib import md5
8
9from plain.models.db import NotSupportedError
10from plain.models.otel import db_span
11from plain.utils.dateparse import parse_time
12
# Module logger; the SQL and transaction debug helpers in this module emit
# their DEBUG timing records through it.
logger = logging.getLogger("plain.models.backends")
14
15
class CursorWrapper:
    """
    Proxy around a database cursor that routes work through its connection.

    Unknown attributes are looked up on the wrapped ``cursor``; the names in
    ``WRAP_ERROR_ATTRS`` are additionally passed through
    ``db.wrap_database_errors``. ``execute``/``executemany`` run through the
    connection's ``execute_wrappers`` chain before reaching the driver.
    """

    WRAP_ERROR_ATTRS = frozenset({"fetchone", "fetchmany", "fetchall", "nextset"})

    def __init__(self, cursor, db):
        self.cursor = cursor
        self.db = db

    def __getattr__(self, attr):
        # Only consulted for names the wrapper itself does not define.
        target = getattr(self.cursor, attr)
        if attr not in CursorWrapper.WRAP_ERROR_ATTRS:
            return target
        return self.db.wrap_database_errors(target)

    def __iter__(self):
        # Errors raised while iterating rows are translated as well.
        with self.db.wrap_database_errors:
            yield from self.cursor

    def __enter__(self):
        return self

    def __exit__(self, type, value, traceback):
        # Close explicitly rather than deferring to the driver's own context
        # manager, to avoid backend-specific behavior (#17671). Database
        # errors raised during this cleanup are ignored.
        try:
            self.close()
        except self.db.Database.Error:
            return None

    # The methods below cannot be served by __getattr__: their logic must run
    # when they are called, not merely when they are looked up.

    def callproc(self, procname, params=None, kparams=None):
        # PEP 249 has no keyword parameters for callproc, but some drivers
        # (e.g. cx_Oracle) accept them; the backend advertises support.
        if kparams is not None and not self.db.features.supports_callproc_kwargs:
            raise NotSupportedError(
                "Keyword parameters for callproc are not supported on this "
                "database backend."
            )
        self.db.validate_no_broken_transaction()
        with self.db.wrap_database_errors:
            if kparams is not None:
                return self.cursor.callproc(procname, params or (), kparams)
            if params is not None:
                return self.cursor.callproc(procname, params)
            return self.cursor.callproc(procname)

    def execute(self, sql, params=None):
        return self._execute_with_wrappers(
            sql, params, many=False, executor=self._execute
        )

    def executemany(self, sql, param_list):
        return self._execute_with_wrappers(
            sql, param_list, many=True, executor=self._executemany
        )

    def _execute_with_wrappers(self, sql, params, many, executor):
        context = {"connection": self.db, "cursor": self}
        # Fold the wrappers so the first one in ``execute_wrappers`` ends up
        # outermost: each wrapper receives the next callable as its first arg.
        chain = functools.reduce(
            lambda inner, wrapper: functools.partial(wrapper, inner),
            reversed(self.db.execute_wrappers),
            executor,
        )
        return chain(sql, params, many, context)

    def _execute(self, sql, params, *ignored_wrapper_args):
        # Wrap in an OpenTelemetry span with standard attributes.
        with db_span(self.db, sql, params=params):
            self.db.validate_no_broken_transaction()
            with self.db.wrap_database_errors:
                # When params is None, call the driver without a params
                # argument at all.
                call_args = (sql,) if params is None else (sql, params)
                return self.cursor.execute(*call_args)

    def _executemany(self, sql, param_list, *ignored_wrapper_args):
        # Same span/validation/error wrapping as _execute, flagged as "many".
        with db_span(self.db, sql, many=True, params=param_list):
            self.db.validate_no_broken_transaction()
            with self.db.wrap_database_errors:
                return self.cursor.executemany(sql, param_list)
98
99
class CursorDebugWrapper(CursorWrapper):
    """
    CursorWrapper that records every ``execute``/``executemany`` call.

    Each call appends a ``{"sql": ..., "time": ...}`` entry to
    ``db.queries_log`` and emits a DEBUG record on the module logger. Both
    happen even when the statement raises.
    """

    # XXX callproc isn't instrumented at this time.

    def execute(self, sql, params=None):
        # Record the SQL as returned by db.ops.last_executed_query rather than
        # the raw statement passed in.
        with self.debug_sql(sql, params, use_last_executed_query=True):
            return super().execute(sql, params)

    def executemany(self, sql, param_list):
        with self.debug_sql(sql, param_list, many=True):
            return super().executemany(sql, param_list)

    @contextmanager
    def debug_sql(
        self, sql=None, params=None, use_last_executed_query=False, many=False
    ):
        """
        Time the enclosed block and record it as a query on exit.

        When ``many`` is true, the logged SQL is prefixed with
        ``"<count> times: "``, where count is ``len(params)``, or ``"?"`` if
        params has no length.
        """
        began = time.monotonic()
        try:
            yield
        finally:
            elapsed = time.monotonic() - began
            if use_last_executed_query:
                sql = self.db.ops.last_executed_query(self.cursor, sql, params)
            if many:
                try:
                    count = len(params)
                except TypeError:
                    # params could be an iterator.
                    count = "?"
                logged_sql = f"{count} times: {sql}"
            else:
                logged_sql = sql
            self.db.queries_log.append(
                {"sql": logged_sql, "time": f"{elapsed:.3f}"}
            )
            logger.debug(
                "(%.3f) %s; args=%s",
                elapsed,
                sql,
                params,
                extra={"duration": elapsed, "sql": sql, "params": params},
            )
145
146
@contextmanager
def debug_transaction(connection, sql):
    """
    Time the enclosed block as the transaction statement ``sql``.

    On exit (including on error), if ``connection.queries_logged`` is truthy,
    append a ``{"sql": ..., "time": ...}`` entry to ``connection.queries_log``
    and emit a DEBUG record on the module logger. Nothing is recorded
    otherwise.
    """
    began = time.monotonic()
    try:
        yield
    finally:
        if connection.queries_logged:
            elapsed = time.monotonic() - began
            entry = {"sql": f"{sql}", "time": f"{elapsed:.3f}"}
            connection.queries_log.append(entry)
            # Transaction statements carry no parameters, hence args=None.
            logger.debug(
                "(%.3f) %s; args=%s",
                elapsed,
                sql,
                None,
                extra={"duration": elapsed, "sql": sql},
            )
172
173
def split_tzname_delta(tzname):
    """
    Split a time zone name into a ``(name, sign, offset)`` 3-tuple.

    "+" is tried before "-". For each sign present, the name is split at its
    last occurrence, and the split is accepted only when the trailing text is
    non-empty and ``parse_time`` returns a truthy value for it. Otherwise
    ``(tzname, None, None)`` is returned.
    """
    for sign in "+-":
        if sign not in tzname:
            continue
        name, offset = tzname.rsplit(sign, 1)
        if offset and parse_time(offset):
            return name, sign, offset
    return tzname, None, None
184
185
186###############################################
187# Converters from database (string) to Python #
188###############################################
189
190
def typecast_date(s):
    """Parse a ``"YYYY-MM-DD"`` string into a ``datetime.date``.

    Falsy input (``None`` or ``""``) yields ``None``.
    """
    if not s:
        return None
    year_month_day = [int(part) for part in s.split("-")]
    return datetime.date(*year_month_day)
195
196
def typecast_time(s):  # does NOT store time zone information
    """Parse ``"HH:MM:SS[.ffffff]"`` into a naive ``datetime.time``.

    Fractional seconds are right-padded with zeros and cut to 6 digits
    (microseconds). Falsy input yields ``None``.
    """
    if not s:
        return None
    hh, mm, ss = s.split(":")
    if "." in ss:
        whole_seconds, fraction = ss.split(".")
    else:
        whole_seconds, fraction = ss, ""
    micro = int(fraction.ljust(6, "0")[:6])
    return datetime.time(int(hh), int(mm), int(whole_seconds), micro)
208
209
def typecast_timestamp(s):  # does NOT store time zone information
    """Parse a ``"YYYY-MM-DD HH:MM:SS[.ffffff][+-TZ]"`` string.

    Examples: "2005-07-29 15:48:00.590358-05", "2005-07-29 09:56:00-05".
    Any trailing UTC offset is discarded and a naive ``datetime`` returned.
    Input without a space is handed to ``typecast_date``; falsy input yields
    ``None``.
    """
    if not s:
        return None
    if " " not in s:
        return typecast_date(s)
    date_str, time_str = s.split()
    # Drop the time zone suffix; "-" is checked before "+".
    for offset_sign in ("-", "+"):
        if offset_sign in time_str:
            time_str = time_str.split(offset_sign, 1)[0]
            break
    date_parts = date_str.split("-")
    time_parts = time_str.split(":")
    sec_field = time_parts[2]
    if "." in sec_field:
        whole_seconds, fraction = sec_field.split(".")
    else:
        whole_seconds, fraction = sec_field, ""
    return datetime.datetime(
        int(date_parts[0]),
        int(date_parts[1]),
        int(date_parts[2]),
        int(time_parts[0]),
        int(time_parts[1]),
        int(whole_seconds),
        int(fraction.ljust(6, "0")[:6]),
    )
239
240
241###############################################
242# Converters from Python to database (string) #
243###############################################
244
245
def split_identifier(identifier):
    """
    Split an SQL identifier into a ``(namespace, name)`` tuple.

    A table, column, or sequence name may be prefixed by a namespace in the
    form ``namespace"."name``. Double quotes at either end of both parts are
    stripped. Unless the identifier contains exactly one ``"."`` separator,
    the namespace is ``""`` and the whole identifier is treated as the name.
    """
    pieces = identifier.split('"."')
    if len(pieces) == 2:
        namespace, name = pieces
    else:
        namespace, name = "", identifier
    return namespace.strip('"'), name.strip('"')
258
259
def truncate_name(identifier, length=None, hash_len=4):
    """
    Shorten an SQL identifier to a repeatable mangled version with the given
    length.

    If a quote-stripped name contains a namespace, e.g. USERNAME"."TABLE,
    truncate the table portion only.

    :param identifier: The identifier, optionally quoted and/or namespaced.
    :param length: Maximum length of the name portion, or ``None`` to disable
        truncation.
    :param hash_len: Number of digest characters (from ``names_digest``)
        appended to a truncated name so distinct long names stay distinct.
    :returns: ``identifier`` unchanged when ``length`` is None or the name
        portion already fits. Otherwise the unquoted namespace (if any) joined
        by ``"."`` to a truncated name that ends in the digest. The name
        portion is never longer than ``length``.
    """
    namespace, name = split_identifier(identifier)

    if length is None or len(name) <= length:
        return identifier

    # Clamp the digest to the available budget. Without this, a length
    # smaller than hash_len makes ``length - hash_len`` a negative slice
    # bound. That slice keeps almost all of the name, so the result would be
    # *longer* than requested.
    hash_len = max(min(hash_len, length), 0)
    digest = names_digest(name, length=hash_len)
    prefix = name[: max(length - hash_len, 0)]
    return "{}{}{}".format(
        f'{namespace}"."' if namespace else "",
        prefix,
        digest,
    )
279
280
def names_digest(*args, length):
    """
    Return the first ``length`` hex characters of an MD5 digest of ``args``.

    The strings are encoded with ``str.encode()`` (UTF-8 by default) and
    hashed in order, as if concatenated. Each hex character carries 4 bits,
    so ``length=8`` gives a 32-bit digest. The hash is used only to derive
    short, stable identifier names (``usedforsecurity=False``).
    """
    payload = b"".join(arg.encode() for arg in args)
    return md5(payload, usedforsecurity=False).hexdigest()[:length]
290
291
def format_number(value, max_digits, decimal_places):
    """
    Render a Decimal as a fixed-point string limited by digits and places.

    ``None`` passes through unchanged. ``max_digits`` (if given) becomes the
    precision of a copy of the current decimal context.

    - With ``decimal_places``, the value is quantized to that many places
      under that context. This raises ``decimal.InvalidOperation`` if the
      result needs more digits.
    - Without ``decimal_places``, the value is re-created in the context with
      the ``Rounded`` signal trapped, so any rounding raises
      ``decimal.Rounded``.
    """
    if value is None:
        return None
    ctx = decimal.getcontext().copy()
    if max_digits is not None:
        ctx.prec = max_digits
    if decimal_places is None:
        ctx.traps[decimal.Rounded] = 1
        value = ctx.create_decimal(value)
    else:
        exponent = decimal.Decimal(1).scaleb(-decimal_places)
        value = value.quantize(exponent, context=ctx)
    return format(value, "f")
310
311
def strip_quotes(table_name):
    """
    Remove one pair of surrounding double quotes from a table name.

    Used to make quoted names usable inside index names, sequence names, etc.
    For example, '"USER"."TABLE"' (an Oracle-style name) becomes
    'USER"."TABLE'. Names that do not both start and end with a quote are
    returned unchanged.
    """
    if table_name.startswith('"') and table_name.endswith('"'):
        return table_name[1:-1]
    return table_name