1from __future__ import annotations
2
3import datetime
4import decimal
5import functools
6import logging
7import time
8from collections.abc import Generator, Iterator, Mapping, Sequence
9from contextlib import contextmanager
10from hashlib import md5
11from types import TracebackType
12from typing import TYPE_CHECKING, Any, Protocol, Self
13
14from plain.models.db import NotSupportedError
15from plain.models.otel import db_span
16from plain.utils.dateparse import parse_time
17
18if TYPE_CHECKING:
19 from plain.models.backends.base.base import BaseDatabaseWrapper
20
# Module logger used by the debug_sql / debug_transaction helpers in this file.
logger: logging.Logger = logging.getLogger("plain.models.backends")
22
23
class DBAPICursor(Protocol):
    """Structural type describing a PEP 249 (DB-API 2.0) cursor.

    Because this is a ``Protocol``, any driver cursor that exposes these
    members satisfies it without inheriting from it.
    """

    @property
    def description(self) -> Sequence[Any] | None:
        """Per-column metadata for the most recent result set, if any."""

    @property
    def rowcount(self) -> int:
        """Row count reported by the driver for the most recent operation."""

    @property
    def lastrowid(self) -> int:
        """Identifier of the most recently inserted row, where supported."""

    def close(self) -> None:
        """Release the cursor."""

    def callproc(self, procname: str, *args: Any, **kwargs: Any) -> Any:
        """Invoke a stored procedure (an optional DB-API 2.0 extension)."""

    def execute(
        self, sql: str, params: Sequence[Any] | Mapping[str, Any] | None = None
    ) -> Any:
        """Run a single statement, optionally with bound parameters."""

    def executemany(self, sql: str, params: Sequence[Sequence[Any]]) -> Any:
        """Run one statement once for every parameter set in *params*."""

    def fetchone(self) -> tuple[Any, ...] | None:
        """Return the next row of the current result set, or None."""

    def fetchmany(self, size: int = 0) -> list[tuple[Any, ...]]:
        """Return up to the next *size* rows of the current result set."""

    def fetchall(self) -> list[tuple[Any, ...]]:
        """Return every remaining row of the current result set."""

    def __iter__(self) -> Iterator[tuple[Any, ...]]:
        """Yield the rows of the current result set one at a time."""
75
76
77class CursorWrapper:
78 def __init__(self, cursor: DBAPICursor, db: BaseDatabaseWrapper) -> None:
79 self.cursor = cursor
80 self.db = db
81
82 WRAP_ERROR_ATTRS = frozenset(["fetchone", "fetchmany", "fetchall", "nextset"])
83
84 def __getattr__(self, attr: str) -> Any:
85 cursor_attr = getattr(self.cursor, attr)
86 if attr in CursorWrapper.WRAP_ERROR_ATTRS:
87 return self.db.wrap_database_errors(cursor_attr)
88 else:
89 return cursor_attr
90
91 def __iter__(self) -> Iterator[tuple[Any, ...]]:
92 with self.db.wrap_database_errors:
93 yield from self.cursor
94
95 def __enter__(self) -> Self:
96 return self
97
98 def __exit__(
99 self,
100 type: type[BaseException] | None,
101 value: BaseException | None,
102 traceback: TracebackType | None,
103 ) -> None:
104 # Close instead of passing through to avoid backend-specific behavior
105 # (#17671). Catch errors liberally because errors in cleanup code
106 # aren't useful.
107 try:
108 self.close()
109 except self.db.Database.Error:
110 pass
111
112 # The following methods cannot be implemented in __getattr__, because the
113 # code must run when the method is invoked, not just when it is accessed.
114
115 def callproc(
116 self,
117 procname: str,
118 params: Sequence[Any] | None = None,
119 kparams: Mapping[str, Any] | None = None,
120 ) -> Any:
121 # Keyword parameters for callproc aren't supported in PEP 249, but the
122 # database driver may support them (e.g. cx_Oracle).
123 if kparams is not None and not self.db.features.supports_callproc_kwargs:
124 raise NotSupportedError(
125 "Keyword parameters for callproc are not supported on this "
126 "database backend."
127 )
128 self.db.validate_no_broken_transaction()
129 with self.db.wrap_database_errors:
130 if params is None and kparams is None:
131 return self.cursor.callproc(procname)
132 elif kparams is None:
133 return self.cursor.callproc(procname, params)
134 else:
135 params = params or ()
136 return self.cursor.callproc(procname, params, kparams)
137
138 def execute(
139 self, sql: str, params: Sequence[Any] | Mapping[str, Any] | None = None
140 ) -> Self:
141 return self._execute_with_wrappers(
142 sql, params, many=False, executor=self._execute
143 )
144
145 def executemany(self, sql: str, param_list: Sequence[Sequence[Any]]) -> Self:
146 return self._execute_with_wrappers(
147 sql, param_list, many=True, executor=self._executemany
148 )
149
150 def _execute_with_wrappers(
151 self, sql: str, params: Any, many: bool, executor: Any
152 ) -> Self:
153 context: dict[str, Any] = {"connection": self.db, "cursor": self}
154 for wrapper in reversed(self.db.execute_wrappers):
155 executor = functools.partial(wrapper, executor)
156 executor(sql, params, many, context)
157 return self
158
159 def _execute(self, sql: str, params: Any, *ignored_wrapper_args: Any) -> None:
160 # Wrap in an OpenTelemetry span with standard attributes.
161 with db_span(self.db, sql, params=params):
162 self.db.validate_no_broken_transaction()
163 with self.db.wrap_database_errors:
164 if params is None:
165 self.cursor.execute(sql)
166 else:
167 self.cursor.execute(sql, params)
168
169 def _executemany(
170 self, sql: str, param_list: Any, *ignored_wrapper_args: Any
171 ) -> None:
172 with db_span(self.db, sql, many=True, params=param_list):
173 self.db.validate_no_broken_transaction()
174 with self.db.wrap_database_errors:
175 self.cursor.executemany(sql, param_list)
176
177
class CursorDebugWrapper(CursorWrapper):
    """CursorWrapper that also times and records every statement it runs.

    Each ``execute``/``executemany`` call is timed, appended to
    ``db.queries_log`` and emitted at DEBUG level on the module logger.
    """

    # XXX callproc isn't instrumented at this time.

    def execute(
        self, sql: str, params: Sequence[Any] | Mapping[str, Any] | None = None
    ) -> Self:
        with self.debug_sql(sql, params, use_last_executed_query=True):
            super().execute(sql, params)
        return self

    def executemany(self, sql: str, param_list: Sequence[Sequence[Any]]) -> Self:
        with self.debug_sql(sql, param_list, many=True):
            super().executemany(sql, param_list)
        return self

    @contextmanager
    def debug_sql(
        self,
        sql: str | None = None,
        params: Any = None,
        use_last_executed_query: bool = False,
        many: bool = False,
    ) -> Generator[None, None, None]:
        """Time the wrapped block, then log and record the statement.

        The recording happens in a ``finally`` clause, so it is attempted
        even when the statement raises. With *many*, the logged SQL is
        prefixed with the number of parameter sets, or ``?`` when *params*
        has no length.
        """
        started = time.monotonic()
        try:
            yield
        finally:
            elapsed = time.monotonic() - started
            if use_last_executed_query:
                # Replace sql with the text reported by the backend ops.
                sql = self.db.ops.last_executed_query(self.cursor, sql, params)  # type: ignore[arg-type]
            if many:
                try:
                    times = len(params)
                except TypeError:
                    # params could be an iterator.
                    times = "?"
                recorded_sql = f"{times} times: {sql}"
            else:
                recorded_sql = sql
            self.db.queries_log.append(
                {"sql": recorded_sql, "time": f"{elapsed:.3f}"}
            )
            logger.debug(
                "(%.3f) %s; args=%s",
                elapsed,
                sql,
                params,
                extra={"duration": elapsed, "sql": sql, "params": params},
            )
231
232
@contextmanager
def debug_transaction(
    connection: BaseDatabaseWrapper, sql: str
) -> Generator[None, None, None]:
    """Time the wrapped block and record it under the label *sql*.

    When ``connection.queries_logged`` is true, an entry is appended to
    ``connection.queries_log`` and a DEBUG record is logged. This happens in a
    ``finally`` clause, so exceptions from the block still propagate.
    """
    started = time.monotonic()
    try:
        yield
    finally:
        if connection.queries_logged:
            elapsed = time.monotonic() - started
            entry = {"sql": f"{sql}", "time": f"{elapsed:.3f}"}
            connection.queries_log.append(entry)
            logger.debug(
                "(%.3f) %s; args=%s",
                elapsed,
                sql,
                None,
                extra={"duration": elapsed, "sql": sql},
            )
260
261
262def split_tzname_delta(tzname: str) -> tuple[str, str | None, str | None]:
263 """
264 Split a time zone name into a 3-tuple of (name, sign, offset).
265 """
266 for sign in ["+", "-"]:
267 if sign in tzname:
268 name, offset = tzname.rsplit(sign, 1)
269 if offset and parse_time(offset):
270 return name, sign, offset
271 return tzname, None, None
272
273
274###############################################
275# Converters from database (string) to Python #
276###############################################
277
278
279def typecast_date(s: str | None) -> datetime.date | None:
280 return (
281 datetime.date(*map(int, s.split("-"))) if s else None
282 ) # return None if s is null
283
284
285def typecast_time(
286 s: str | None,
287) -> datetime.time | None: # does NOT store time zone information
288 if not s:
289 return None
290 hour, minutes, seconds = s.split(":")
291 if "." in seconds: # check whether seconds have a fractional part
292 seconds, microseconds = seconds.split(".")
293 else:
294 microseconds = "0"
295 return datetime.time(
296 int(hour), int(minutes), int(seconds), int((microseconds + "000000")[:6])
297 )
298
299
300def typecast_timestamp(
301 s: str | None,
302) -> datetime.date | datetime.datetime | None: # does NOT store time zone information
303 # "2005-07-29 15:48:00.590358-05"
304 # "2005-07-29 09:56:00-05"
305 if not s:
306 return None
307 if " " not in s:
308 return typecast_date(s)
309 d, t = s.split()
310 # Remove timezone information.
311 if "-" in t:
312 t, _ = t.split("-", 1)
313 elif "+" in t:
314 t, _ = t.split("+", 1)
315 dates = d.split("-")
316 times = t.split(":")
317 seconds = times[2]
318 if "." in seconds: # check whether seconds have a fractional part
319 seconds, microseconds = seconds.split(".")
320 else:
321 microseconds = "0"
322 return datetime.datetime(
323 int(dates[0]),
324 int(dates[1]),
325 int(dates[2]),
326 int(times[0]),
327 int(times[1]),
328 int(seconds),
329 int((microseconds + "000000")[:6]),
330 )
331
332
333###############################################
334# Converters from Python to database (string) #
335###############################################
336
337
def split_identifier(identifier: str) -> tuple[str, str]:
    """
    Split an SQL identifier into a ``(namespace, name)`` pair.

    The identifier (a table, column, or sequence name) may be prefixed by a
    namespace, e.g. ``"USER"."TABLE"``. Unless it contains exactly one
    ``"."`` separator, the namespace is "" and the whole identifier is the
    name. Surrounding double quotes are stripped from both parts.
    """
    pieces = identifier.split('"."')
    if len(pieces) == 2:
        namespace, name = pieces
    else:
        namespace, name = "", identifier
    return namespace.strip('"'), name.strip('"')
350
351
def truncate_name(identifier: str, length: int | None = None, hash_len: int = 4) -> str:
    """
    Shorten an SQL identifier to a repeatable mangled version.

    If *length* is None, or the name portion already fits in *length*
    characters, *identifier* is returned unchanged. Otherwise the name keeps
    its first ``length - hash_len`` characters, followed by a
    ``hash_len``-character digest of the full name. The same input always
    gives the same output. If the identifier has a namespace, e.g.
    ``USERNAME"."TABLE``, only the table portion is truncated.
    """
    namespace, name = split_identifier(identifier)
    if length is not None and len(name) > length:
        digest = names_digest(name, length=hash_len)
        prefix = f'{namespace}"."' if namespace else ""
        return prefix + name[: length - hash_len] + digest
    return identifier
371
372
def names_digest(*args: str, length: int) -> str:
    """
    Return the first *length* hex characters of an MD5 digest of *args*.

    Each argument is encoded with ``str.encode()`` (UTF-8) and the bytes are
    hashed in order, as if concatenated. The result is deterministic, which
    makes it suitable for shortening generated names. MD5 is requested with
    ``usedforsecurity=False``: this is not a security primitive.
    """
    payload = b"".join(part.encode() for part in args)
    return md5(payload, usedforsecurity=False).hexdigest()[:length]
382
383
384def format_number(
385 value: decimal.Decimal | None, max_digits: int | None, decimal_places: int | None
386) -> str | None:
387 """
388 Format a number into a string with the requisite number of digits and
389 decimal places.
390 """
391 if value is None:
392 return None
393 context = decimal.getcontext().copy()
394 if max_digits is not None:
395 context.prec = max_digits
396 if decimal_places is not None:
397 value = value.quantize(
398 decimal.Decimal(1).scaleb(-decimal_places), context=context
399 )
400 else:
401 context.traps[decimal.Rounded] = True
402 value = context.create_decimal(value)
403 return f"{value:f}"
404
405
def strip_quotes(table_name: str) -> str:
    """
    Remove one pair of surrounding double quotes from *table_name*, if present.

    This makes quoted table names usable inside index names, sequence names,
    etc. For example, the Oracle-style '"USER"."TABLE"' becomes 'USER"."TABLE'.
    Names not wrapped in double quotes are returned unchanged.
    """
    wrapped = table_name[:1] == table_name[-1:] == '"'
    if wrapped:
        return table_name[1:-1]
    return table_name