1from __future__ import annotations
2
3import datetime
4import decimal
5import functools
6import logging
7import time
8from collections.abc import Generator, Iterator, Mapping, Sequence
9from contextlib import contextmanager
10from hashlib import md5
11from types import TracebackType
12from typing import TYPE_CHECKING, Any, Protocol, Self
13
14from plain.models.db import NotSupportedError
15from plain.models.otel import db_span
16from plain.utils.dateparse import parse_time
17
18if TYPE_CHECKING:
19 from plain.models.backends.base.base import BaseDatabaseWrapper
20
21logger = logging.getLogger("plain.models.backends")
22
23
class DBAPICursor(Protocol):
    """
    Typing protocol for a DB-API 2.0 (PEP 249) cursor object.

    Any driver cursor exposing these members structurally satisfies the
    protocol; nothing here is executed at runtime.
    """

    # -- Metadata about the most recent operation ---------------------------

    @property
    def description(self) -> Sequence[Any] | None:
        """Per-column descriptions of the latest result set, or None."""
        ...

    @property
    def rowcount(self) -> int:
        """Row count reported by the driver for the latest operation."""
        ...

    @property
    def lastrowid(self) -> int:
        """Identifier of the last inserted row, where the driver supports it."""
        ...

    # -- Running statements -------------------------------------------------

    def execute(
        self, sql: str, params: Sequence[Any] | Mapping[str, Any] | None = None
    ) -> Any:
        """Run a single SQL statement, optionally with bound parameters."""
        ...

    def executemany(self, sql: str, params: Sequence[Sequence[Any]]) -> Any:
        """Run one SQL statement once per parameter set in ``params``."""
        ...

    def callproc(self, procname: str, *args: Any, **kwargs: Any) -> Any:
        """Invoke a stored procedure (an optional DB-API 2.0 extension)."""
        ...

    # -- Reading results ----------------------------------------------------

    def fetchone(self) -> tuple[Any, ...] | None:
        """Return the next result row, or None once rows are exhausted."""
        ...

    def fetchmany(self, size: int | None = None) -> list[tuple[Any, ...]]:
        """Return up to ``size`` further result rows."""
        ...

    def fetchall(self) -> list[tuple[Any, ...]]:
        """Return every remaining result row."""
        ...

    def __iter__(self) -> Iterator[tuple[Any, ...]]:
        """Yield the remaining result rows one at a time."""
        ...

    # -- Lifecycle ----------------------------------------------------------

    def close(self) -> None:
        """Release the cursor."""
        ...
75
76
77class CursorWrapper:
78 def __init__(self, cursor: DBAPICursor, db: BaseDatabaseWrapper) -> None:
79 self.cursor = cursor
80 self.db = db
81
82 WRAP_ERROR_ATTRS = frozenset(["nextset"])
83
84 def __getattr__(self, attr: str) -> Any:
85 cursor_attr = getattr(self.cursor, attr)
86 if attr in CursorWrapper.WRAP_ERROR_ATTRS:
87 return self.db.wrap_database_errors(cursor_attr)
88 else:
89 return cursor_attr
90
91 def __iter__(self) -> Iterator[tuple[Any, ...]]:
92 with self.db.wrap_database_errors:
93 yield from self.cursor
94
95 def fetchone(self) -> tuple[Any, ...] | None:
96 with self.db.wrap_database_errors:
97 return self.cursor.fetchone()
98
99 def fetchmany(self, size: int | None = None) -> list[tuple[Any, ...]]:
100 with self.db.wrap_database_errors:
101 if size is None:
102 return self.cursor.fetchmany()
103 return self.cursor.fetchmany(size)
104
105 def fetchall(self) -> list[tuple[Any, ...]]:
106 with self.db.wrap_database_errors:
107 return self.cursor.fetchall()
108
109 def __enter__(self) -> Self:
110 return self
111
112 def __exit__(
113 self,
114 type: type[BaseException] | None,
115 value: BaseException | None,
116 traceback: TracebackType | None,
117 ) -> None:
118 # Close instead of passing through to avoid backend-specific behavior
119 # (#17671). Catch errors liberally because errors in cleanup code
120 # aren't useful.
121 try:
122 self.close()
123 except self.db.Database.Error:
124 pass
125
126 # The following methods cannot be implemented in __getattr__, because the
127 # code must run when the method is invoked, not just when it is accessed.
128
129 def callproc(
130 self,
131 procname: str,
132 params: Sequence[Any] | None = None,
133 kparams: Mapping[str, Any] | None = None,
134 ) -> Any:
135 # Keyword parameters for callproc aren't supported in PEP 249, but the
136 # database driver may support them (e.g. cx_Oracle).
137 if kparams is not None and not self.db.features.supports_callproc_kwargs:
138 raise NotSupportedError(
139 "Keyword parameters for callproc are not supported on this "
140 "database backend."
141 )
142 self.db.validate_no_broken_transaction()
143 with self.db.wrap_database_errors:
144 if params is None and kparams is None:
145 return self.cursor.callproc(procname)
146 elif kparams is None:
147 return self.cursor.callproc(procname, params)
148 else:
149 params = params or ()
150 return self.cursor.callproc(procname, params, kparams)
151
152 def execute(
153 self, sql: str, params: Sequence[Any] | Mapping[str, Any] | None = None
154 ) -> Self:
155 return self._execute_with_wrappers(
156 sql, params, many=False, executor=self._execute
157 )
158
159 def executemany(self, sql: str, param_list: Sequence[Sequence[Any]]) -> Self:
160 return self._execute_with_wrappers(
161 sql, param_list, many=True, executor=self._executemany
162 )
163
164 def _execute_with_wrappers(
165 self, sql: str, params: Any, many: bool, executor: Any
166 ) -> Self:
167 context: dict[str, Any] = {"connection": self.db, "cursor": self}
168 for wrapper in reversed(self.db.execute_wrappers):
169 executor = functools.partial(wrapper, executor)
170 executor(sql, params, many, context)
171 return self
172
173 def _execute(self, sql: str, params: Any, *ignored_wrapper_args: Any) -> None:
174 # Wrap in an OpenTelemetry span with standard attributes.
175 with db_span(self.db, sql, params=params):
176 self.db.validate_no_broken_transaction()
177 with self.db.wrap_database_errors:
178 if params is None:
179 self.cursor.execute(sql)
180 else:
181 self.cursor.execute(sql, params)
182
183 def _executemany(
184 self, sql: str, param_list: Any, *ignored_wrapper_args: Any
185 ) -> None:
186 with db_span(self.db, sql, many=True, params=param_list):
187 self.db.validate_no_broken_transaction()
188 with self.db.wrap_database_errors:
189 self.cursor.executemany(sql, param_list)
190
191
class CursorDebugWrapper(CursorWrapper):
    """
    ``CursorWrapper`` that times and logs every executed statement.

    Each ``execute``/``executemany`` call is timed. The SQL and elapsed time
    are appended to ``db.queries_log`` and emitted via ``logger.debug``.
    """

    # XXX callproc isn't instrumented at this time.

    def execute(
        self, sql: str, params: Sequence[Any] | Mapping[str, Any] | None = None
    ) -> Self:
        # For a single statement, the logged SQL comes from the backend's
        # db.ops.last_executed_query rather than the raw ``sql`` argument.
        with self.debug_sql(sql=sql, params=params, use_last_executed_query=True):
            super().execute(sql, params)
        return self

    def executemany(self, sql: str, param_list: Sequence[Sequence[Any]]) -> Self:
        with self.debug_sql(sql=sql, params=param_list, many=True):
            super().executemany(sql, param_list)
        return self

    @contextmanager
    def debug_sql(
        self,
        sql: str | None = None,
        params: Any = None,
        use_last_executed_query: bool = False,
        many: bool = False,
    ) -> Generator[None, None, None]:
        """
        Time the enclosed block and record it, even when the block raises.

        If ``use_last_executed_query`` is true, the recorded SQL is replaced by
        ``db.ops.last_executed_query(cursor, sql, params)``. If ``many`` is
        true, the queries_log entry is prefixed with the number of parameter
        sets, or ``?`` when ``params`` has no ``len()``.
        """
        started = time.monotonic()
        try:
            yield
        finally:
            # Measure first, so the bookkeeping below is not counted as part
            # of the statement's duration.
            elapsed = time.monotonic() - started
            if use_last_executed_query:
                sql = self.db.ops.last_executed_query(self.cursor, sql, params)  # type: ignore[arg-type]
            logged_sql: str | None
            if many:
                batch_size: int | str
                try:
                    batch_size = len(params)
                except TypeError:
                    # params could be an iterator, which has no len().
                    batch_size = "?"
                logged_sql = f"{batch_size} times: {sql}"
            else:
                logged_sql = sql
            self.db.queries_log.append(
                {"sql": logged_sql, "time": f"{elapsed:.3f}"}
            )
            logger.debug(
                "(%.3f) %s; args=%s",
                elapsed,
                sql,
                params,
                extra={"duration": elapsed, "sql": sql, "params": params},
            )
245
246
@contextmanager
def debug_transaction(
    connection: BaseDatabaseWrapper, sql: str
) -> Generator[None, None, None]:
    """
    Time the enclosed block and, if ``connection.queries_logged`` is true,
    record ``sql`` with its duration in ``connection.queries_log`` and in the
    debug log.

    The record is written even when the enclosed block raises. The exception
    still propagates.
    """
    began = time.monotonic()
    try:
        yield
    finally:
        if connection.queries_logged:
            elapsed = time.monotonic() - began
            connection.queries_log.append(
                {"sql": f"{sql}", "time": f"{elapsed:.3f}"}
            )
            logger.debug(
                "(%.3f) %s; args=%s",
                elapsed,
                sql,
                None,
                extra={"duration": elapsed, "sql": sql},
            )
274
275
276def split_tzname_delta(tzname: str) -> tuple[str, str | None, str | None]:
277 """
278 Split a time zone name into a 3-tuple of (name, sign, offset).
279 """
280 for sign in ["+", "-"]:
281 if sign in tzname:
282 name, offset = tzname.rsplit(sign, 1)
283 if offset and parse_time(offset):
284 return name, sign, offset
285 return tzname, None, None
286
287
288###############################################
289# Converters from database (string) to Python #
290###############################################
291
292
293def typecast_date(s: str | None) -> datetime.date | None:
294 return (
295 datetime.date(*map(int, s.split("-"))) if s else None
296 ) # return None if s is null
297
298
299def typecast_time(
300 s: str | None,
301) -> datetime.time | None: # does NOT store time zone information
302 if not s:
303 return None
304 hour, minutes, seconds = s.split(":")
305 if "." in seconds: # check whether seconds have a fractional part
306 seconds, microseconds = seconds.split(".")
307 else:
308 microseconds = "0"
309 return datetime.time(
310 int(hour), int(minutes), int(seconds), int((microseconds + "000000")[:6])
311 )
312
313
314def typecast_timestamp(
315 s: str | None,
316) -> datetime.date | datetime.datetime | None: # does NOT store time zone information
317 # "2005-07-29 15:48:00.590358-05"
318 # "2005-07-29 09:56:00-05"
319 if not s:
320 return None
321 if " " not in s:
322 return typecast_date(s)
323 d, t = s.split()
324 # Remove timezone information.
325 if "-" in t:
326 t, _ = t.split("-", 1)
327 elif "+" in t:
328 t, _ = t.split("+", 1)
329 dates = d.split("-")
330 times = t.split(":")
331 seconds = times[2]
332 if "." in seconds: # check whether seconds have a fractional part
333 seconds, microseconds = seconds.split(".")
334 else:
335 microseconds = "0"
336 return datetime.datetime(
337 int(dates[0]),
338 int(dates[1]),
339 int(dates[2]),
340 int(times[0]),
341 int(times[1]),
342 int(seconds),
343 int((microseconds + "000000")[:6]),
344 )
345
346
347###############################################
348# Converters from Python to database (string) #
349###############################################
350
351
def split_identifier(identifier: str) -> tuple[str, str]:
    """
    Split an SQL identifier into ``(namespace, name)``.

    The identifier (a table, column or sequence name) may be prefixed by a
    namespace joined with ``"."``. Without exactly one such separator, the
    namespace is empty. Surrounding double quotes are stripped from both
    parts.
    """
    pieces = identifier.split('"."')
    if len(pieces) == 2:
        namespace, name = pieces
    else:
        namespace, name = "", identifier
    return namespace.strip('"'), name.strip('"')
364
365
def truncate_name(identifier: str, length: int | None = None, hash_len: int = 4) -> str:
    """
    Shorten an SQL identifier to a repeatable, hash-suffixed form.

    If ``length`` is None or the name already fits, ``identifier`` is
    returned unchanged. Otherwise the name is cut to ``length - hash_len``
    characters and a ``hash_len``-character digest of the full name is
    appended. If the identifier has a namespace (e.g. USERNAME"."TABLE), only
    the table portion is truncated and the namespace is kept in front.
    """
    namespace, name = split_identifier(identifier)
    if length is None or len(name) <= length:
        return identifier
    prefix = f'{namespace}"."' if namespace else ""
    return prefix + name[: length - hash_len] + names_digest(name, length=hash_len)
385
386
def names_digest(*args: str, length: int) -> str:
    """
    Return a short hex digest of ``args`` for building shortened names.

    The strings are UTF-8 encoded and concatenated, hashed with MD5 (flagged
    as not for security use), and the first ``length`` hex characters are
    returned. Each hex character carries 4 bits, so ``length=8`` gives a
    32-bit digest.
    """
    payload = b"".join(arg.encode() for arg in args)
    return md5(payload, usedforsecurity=False).hexdigest()[:length]
396
397
398def format_number(
399 value: decimal.Decimal | None, max_digits: int | None, decimal_places: int | None
400) -> str | None:
401 """
402 Format a number into a string with the requisite number of digits and
403 decimal places.
404 """
405 if value is None:
406 return None
407 context = decimal.getcontext().copy()
408 if max_digits is not None:
409 context.prec = max_digits
410 if decimal_places is not None:
411 value = value.quantize(
412 decimal.Decimal(1).scaleb(-decimal_places), context=context
413 )
414 else:
415 context.traps[decimal.Rounded] = True
416 value = context.create_decimal(value)
417 return f"{value:f}"
418
419
def strip_quotes(table_name: str) -> str:
    """
    Remove one pair of enclosing double quotes from a table name, if present.

    This makes quoted names usable inside index names, sequence names, etc.
    For example '"USER"."TABLE"' (an Oracle naming scheme) becomes
    'USER"."TABLE'. Names not both starting and ending with a quote are
    returned unchanged.
    """
    if table_name.startswith('"') and table_name.endswith('"'):
        return table_name[1:-1]
    return table_name