1from __future__ import annotations
2
3import datetime
4import decimal
5import functools
6import logging
7import time
8from collections.abc import Generator, Iterator
9from contextlib import contextmanager
10from hashlib import md5
11from typing import TYPE_CHECKING, Any
12
13from plain.models.db import NotSupportedError
14from plain.models.otel import db_span
15from plain.utils.dateparse import parse_time
16
17if TYPE_CHECKING:
18 from plain.models.backends.base.base import BaseDatabaseWrapper
19
# Logger for the backend helpers in this module; CursorDebugWrapper.debug_sql
# and debug_transaction emit their per-statement DEBUG records through it.
logger = logging.getLogger("plain.models.backends")
21
22
class CursorWrapper:
    """
    Proxy around a DB-API cursor that routes calls through the connection.

    Attributes not defined on the wrapper are looked up on the wrapped
    cursor; the methods listed in ``WRAP_ERROR_ATTRS`` are additionally
    passed through ``db.wrap_database_errors``. ``execute``/``executemany``
    go through the connection's ``execute_wrappers`` chain before they reach
    the cursor.
    """

    # Cursor methods whose looked-up callables are wrapped by
    # ``db.wrap_database_errors`` in ``__getattr__``.
    WRAP_ERROR_ATTRS = frozenset(["fetchone", "fetchmany", "fetchall", "nextset"])

    def __init__(self, cursor: Any, db: Any) -> None:
        self.cursor = cursor
        self.db = db

    def __getattr__(self, attr: str) -> Any:
        # Only consulted for names the wrapper itself does not define.
        target = getattr(self.cursor, attr)
        if attr not in CursorWrapper.WRAP_ERROR_ATTRS:
            return target
        return self.db.wrap_database_errors(target)

    def __iter__(self) -> Iterator[Any]:
        # Errors raised while iterating go through db.wrap_database_errors.
        with self.db.wrap_database_errors:
            yield from self.cursor

    def __enter__(self) -> CursorWrapper:
        return self

    def __exit__(self, type: Any, value: Any, traceback: Any) -> None:
        # Close explicitly instead of delegating to the driver's own exit
        # handling, which is backend-specific (#17671). A database error
        # raised during this cleanup is not useful, so it is ignored.
        try:
            self.close()  # type: ignore[attr-defined]
        except self.db.Database.Error:
            pass

    # The methods below can't be served by __getattr__: their logic has to
    # run when they are called, not merely when they are looked up.

    def callproc(self, procname: str, params: Any = None, kparams: Any = None) -> Any:
        # PEP 249 defines no keyword parameters for callproc, but some
        # drivers (e.g. cx_Oracle) accept them; allow them only where the
        # backend advertises support.
        if kparams is not None and not self.db.features.supports_callproc_kwargs:
            raise NotSupportedError(
                "Keyword parameters for callproc are not supported on this "
                "database backend."
            )
        self.db.validate_no_broken_transaction()
        with self.db.wrap_database_errors:
            if kparams is not None:
                return self.cursor.callproc(procname, params or (), kparams)
            if params is not None:
                return self.cursor.callproc(procname, params)
            return self.cursor.callproc(procname)

    def execute(self, sql: str, params: Any = None) -> Any:
        return self._execute_with_wrappers(
            sql, params, many=False, executor=self._execute
        )

    def executemany(self, sql: str, param_list: Any) -> Any:
        return self._execute_with_wrappers(
            sql, param_list, many=True, executor=self._executemany
        )

    def _execute_with_wrappers(
        self, sql: str, params: Any, many: bool, executor: Any
    ) -> Any:
        # Build the call chain inside-out so execute_wrappers[0] ends up
        # outermost; each wrapper gets the next callable as its first arg.
        call = executor
        for wrapper in reversed(self.db.execute_wrappers):
            call = functools.partial(wrapper, call)
        context: dict[str, Any] = {"connection": self.db, "cursor": self}
        return call(sql, params, many, context)

    def _execute(self, sql: str, params: Any, *ignored_wrapper_args: Any) -> Any:
        # Every statement runs inside an OpenTelemetry span.
        with db_span(self.db, sql, params=params):
            self.db.validate_no_broken_transaction()
            with self.db.wrap_database_errors:
                # Pass params to the driver only when some were supplied.
                args = (sql,) if params is None else (sql, params)
                return self.cursor.execute(*args)

    def _executemany(
        self, sql: str, param_list: Any, *ignored_wrapper_args: Any
    ) -> Any:
        with db_span(self.db, sql, many=True, params=param_list):
            self.db.validate_no_broken_transaction()
            with self.db.wrap_database_errors:
                return self.cursor.executemany(sql, param_list)
109
110
class CursorDebugWrapper(CursorWrapper):
    """
    CursorWrapper that also records every ``execute``/``executemany`` call.

    Each call is timed; the statement is appended to ``db.queries_log`` and
    emitted through the module logger at DEBUG level.
    """

    # XXX callproc isn't instrumented at this time.

    def execute(self, sql: str, params: Any = None) -> Any:
        with self.debug_sql(sql, params, use_last_executed_query=True):
            return super().execute(sql, params)

    def executemany(self, sql: str, param_list: Any) -> Any:
        with self.debug_sql(sql, param_list, many=True):
            return super().executemany(sql, param_list)

    @contextmanager
    def debug_sql(
        self,
        sql: str | None = None,
        params: Any = None,
        use_last_executed_query: bool = False,
        many: bool = False,
    ) -> Generator[None, None, None]:
        """
        Time the enclosed statement and record it when the block exits.

        Recording happens in ``finally``, so statements that raise are
        logged too. With ``use_last_executed_query`` the SQL text is replaced
        by ``db.ops.last_executed_query(...)``. With ``many`` the entry in
        ``queries_log`` is prefixed by the number of parameter sets, or
        ``"?"`` when ``params`` has no length.
        """
        started = time.monotonic()
        try:
            yield
        finally:
            elapsed = time.monotonic() - started
            if use_last_executed_query:
                sql = self.db.ops.last_executed_query(self.cursor, sql, params)
            if many:
                try:
                    count = len(params)  # type: ignore[arg-type]
                except TypeError:
                    # params could be an iterator.
                    count = "?"
                logged_sql = f"{count} times: {sql}"
            else:
                logged_sql = sql
            self.db.queries_log.append({"sql": logged_sql, "time": f"{elapsed:.3f}"})
            logger.debug(
                "(%.3f) %s; args=%s",
                elapsed,
                sql,
                params,
                extra={"duration": elapsed, "sql": sql, "params": params},
            )
160
161
@contextmanager
def debug_transaction(
    connection: BaseDatabaseWrapper, sql: str
) -> Generator[None, None, None]:
    """
    Time the enclosed block as the execution of ``sql``.

    When ``connection.queries_logged`` is true at exit (including exit by
    exception), append ``{"sql", "time"}`` to ``connection.queries_log`` and
    emit a DEBUG record through the module logger. Otherwise nothing is
    recorded.
    """
    began = time.monotonic()
    try:
        yield
    finally:
        if connection.queries_logged:
            elapsed = time.monotonic() - began
            connection.queries_log.append({"sql": f"{sql}", "time": f"{elapsed:.3f}"})
            logger.debug(
                "(%.3f) %s; args=%s",
                elapsed,
                sql,
                None,
                extra={"duration": elapsed, "sql": sql},
            )
189
190
def split_tzname_delta(tzname: str) -> tuple[str, str | None, str | None]:
    """
    Split a time zone name into a ``(name, sign, offset)`` triple.

    ``+`` is tried before ``-``. For each sign present, the text after its
    last occurrence is taken as the offset. That offset is accepted only
    when it is non-empty and ``parse_time`` accepts it. If neither sign
    yields an offset, ``(tzname, None, None)`` is returned.
    """
    for sign in ("+", "-"):
        if sign not in tzname:
            continue
        name, _, offset = tzname.rpartition(sign)
        if offset and parse_time(offset):
            return name, sign, offset
    return tzname, None, None
201
202
203###############################################
204# Converters from database (string) to Python #
205###############################################
206
207
def typecast_date(s: str | None) -> datetime.date | None:
    """
    Convert a dash-separated ``year-month-day`` string into a ``date``.

    Returns ``None`` for ``None`` or an empty string.
    """
    if not s:
        return None
    return datetime.date(*(int(part) for part in s.split("-")))
212
213
def typecast_time(
    s: str | None,
) -> datetime.time | None:
    """
    Convert ``HH:MM:SS[.fraction]`` into a naive ``time``.

    No time zone information is stored. The fractional part is
    right-padded with zeros and cut to six digits (microseconds), so
    ``.5`` means 500000us and extra digits are truncated. Returns ``None``
    for ``None`` or an empty string.
    """
    if not s:
        return None
    hour, minutes, seconds = s.split(":")
    if "." in seconds:
        whole, fraction = seconds.split(".")
    else:
        whole, fraction = seconds, "0"
    micro = int(fraction.ljust(6, "0")[:6])
    return datetime.time(int(hour), int(minutes), int(whole), micro)
227
228
def typecast_timestamp(
    s: str | None,
) -> datetime.date | datetime.datetime | None:
    """
    Convert a ``date time`` string into a naive ``datetime``.

    A trailing ``-``/``+`` offset on the time part is discarded, not
    applied, so no time zone information is stored. Example inputs::

        "2005-07-29 15:48:00.590358-05"
        "2005-07-29 09:56:00-05"

    A string without a space is delegated to ``typecast_date``. Returns
    ``None`` for ``None`` or an empty string.
    """
    if not s:
        return None
    if " " not in s:
        return typecast_date(s)
    date_part, time_part = s.split()
    # Strip the time zone suffix; "-" is checked before "+".
    for separator in ("-", "+"):
        if separator in time_part:
            time_part = time_part.split(separator, 1)[0]
            break
    ymd = date_part.split("-")
    hms = time_part.split(":")
    seconds = hms[2]
    if "." in seconds:
        seconds, fraction = seconds.split(".")
    else:
        fraction = "0"
    return datetime.datetime(
        int(ymd[0]),
        int(ymd[1]),
        int(ymd[2]),
        int(hms[0]),
        int(hms[1]),
        int(seconds),
        int(fraction.ljust(6, "0")[:6]),
    )
260
261
262###############################################
263# Converters from Python to database (string) #
264###############################################
265
266
def split_identifier(identifier: str) -> tuple[str, str]:
    """
    Split an SQL identifier into ``(namespace, name)``.

    The identifier (a table, column or sequence name) may be prefixed by a
    namespace separated by ``"."``. When there is not exactly one such
    separator, the namespace is ``""`` and the whole identifier is the name.
    Surrounding double quotes are stripped from both parts.
    """
    pieces = identifier.split('"."')
    if len(pieces) == 2:
        namespace, name = pieces
    else:
        namespace, name = "", identifier
    return namespace.strip('"'), name.strip('"')
279
280
def truncate_name(identifier: str, length: int | None = None, hash_len: int = 4) -> str:
    """
    Shorten an SQL identifier to a repeatable mangled version.

    If ``length`` is ``None`` or the (quote-stripped) name already fits in
    ``length`` characters, ``identifier`` is returned unchanged. Otherwise
    the name is cut down and suffixed with a ``hash_len``-character digest of
    the full name, so that the mangled name is exactly ``length`` characters.
    If a namespace is present (e.g. ``USERNAME"."TABLE``), only the name part
    is truncated; the namespace is kept and re-joined with ``"."``.

    :param identifier: identifier, optionally namespaced and/or quoted.
    :param length: maximum length of the name part, or ``None`` for no limit.
    :param hash_len: number of digest characters to append; clamped to
        ``length`` so the result never exceeds the budget.
    """
    namespace, name = split_identifier(identifier)

    if length is None or len(name) <= length:
        return identifier

    # Without this clamp, hash_len > length would make the slice below
    # negative (name[:-k]) and keep most of the name, overshooting `length`.
    hash_len = min(hash_len, length)
    digest = names_digest(name, length=hash_len)
    prefix = f'{namespace}"."' if namespace else ""
    return f"{prefix}{name[: length - hash_len]}{digest}"
300
301
def names_digest(*args: str, length: int) -> str:
    """
    Return the first ``length`` hex characters of the MD5 digest of ``args``.

    The arguments are encoded and hashed in order as one continuous payload,
    and the result is used to build short, repeatable identifier suffixes.
    Each hex character carries 4 bits, so ``length=8`` gives a 32-bit digest.
    MD5 is requested with ``usedforsecurity=False`` because the hash serves
    only as a name digest.
    """
    payload = b"".join(arg.encode() for arg in args)
    return md5(payload, usedforsecurity=False).hexdigest()[:length]
311
312
def format_number(
    value: decimal.Decimal | None, max_digits: int | None, decimal_places: int | None
) -> str | None:
    """
    Render ``value`` as a fixed-point string.

    Arithmetic runs in a copy of the current decimal context whose precision
    is ``max_digits`` when that is given.

    * With ``decimal_places``, the value is quantized to that many places.
    * Without it, the value is re-created in the context with the
      ``Rounded`` signal trapped. A value with more coefficient digits than
      the precision therefore raises ``decimal.Rounded`` instead of being
      silently rounded.

    Returns ``None`` when ``value`` is ``None``.
    """
    if value is None:
        return None
    ctx = decimal.getcontext().copy()
    if max_digits is not None:
        ctx.prec = max_digits
    if decimal_places is None:
        ctx.traps[decimal.Rounded] = 1  # type: ignore[assignment]
        value = ctx.create_decimal(value)
    else:
        quantum = decimal.Decimal(1).scaleb(-decimal_places)
        value = value.quantize(quantum, context=ctx)
    return f"{value:f}"
333
334
def strip_quotes(table_name: str) -> str:
    """
    Remove one pair of surrounding double quotes from a table name.

    This makes quoted names safe to embed in index names, sequence names,
    etc. For example, the Oracle-style ``'"USER"."TABLE"'`` becomes
    ``'USER"."TABLE'``. Names that do not both start and end with ``"`` are
    returned unchanged.
    """
    if table_name.startswith('"') and table_name.endswith('"'):
        return table_name[1:-1]
    return table_name