1from __future__ import annotations
2
3import functools
4import logging
5import time
6from collections.abc import Generator, Iterator, Mapping, Sequence
7from contextlib import contextmanager
8from hashlib import md5
9from types import TracebackType
10from typing import TYPE_CHECKING, Any, Self
11
12import psycopg
13
14from plain.postgres.db import NotSupportedError
15from plain.postgres.otel import db_span
16from plain.utils.dateparse import parse_time
17
18if TYPE_CHECKING:
19 from plain.postgres.connection import DatabaseConnection
20
21logger = logging.getLogger("plain.postgres.utils")
22
23
def make_model_tuple(model: Any) -> tuple[str, str]:
    """
    Normalize a model reference to a ("package_label", "modelname") tuple.

    ``model`` may be:
    - a tuple, assumed to already be a model tuple and returned unchanged
      (only its length is validated);
    - a string of the form "package_label.ModelName"; the model name part is
      lower-cased;
    - any object exposing ``model_options.package_label`` and
      ``model_options.model_name``.

    Raises:
        ValueError: if a string reference does not contain exactly one ".",
            or a tuple does not have exactly two items.
    """
    try:
        if isinstance(model, tuple):
            model_tuple = model
        elif isinstance(model, str):
            # Unpacking raises ValueError unless there is exactly one ".".
            package_label, model_name = model.split(".")
            model_tuple = package_label, model_name.lower()
        else:
            model_tuple = (
                model.model_options.package_label,
                model.model_options.model_name,
            )
        # Explicit check rather than ``assert`` so the validation still runs
        # when Python is started with -O (which strips asserts).
        if len(model_tuple) != 2:
            raise ValueError(f"Expected 2 items, got {len(model_tuple)}.")
    except ValueError as err:
        raise ValueError(
            f"Invalid model reference '{model}'. String model references "
            "must be of the form 'package_label.ModelName'."
        ) from err
    return model_tuple
48
49
def resolve_callables(
    mapping: dict[str, Any],
) -> Generator[tuple[str, Any]]:
    """
    Yield ``(key, value)`` pairs from ``mapping``.

    A callable value is invoked with no arguments, and its return value is
    yielded in its place. Values are resolved one at a time, as the
    generator is advanced.
    """
    for key, value in mapping.items():
        if callable(value):
            value = value()
        yield key, value
59
60
61class CursorWrapper:
62 def __init__(self, cursor: Any, db: DatabaseConnection) -> None:
63 self.cursor = cursor
64 self.db = db
65
66 WRAP_ERROR_ATTRS = frozenset(["nextset"])
67
68 def __getattr__(self, attr: str) -> Any:
69 cursor_attr = getattr(self.cursor, attr)
70 if attr in CursorWrapper.WRAP_ERROR_ATTRS:
71 return self.db.wrap_database_errors(cursor_attr)
72 else:
73 return cursor_attr
74
75 def __iter__(self) -> Iterator[tuple[Any, ...]]:
76 with self.db.wrap_database_errors:
77 yield from self.cursor
78
79 def fetchone(self) -> tuple[Any, ...] | None:
80 with self.db.wrap_database_errors:
81 return self.cursor.fetchone()
82
83 def fetchmany(self, size: int | None = None) -> list[tuple[Any, ...]]:
84 with self.db.wrap_database_errors:
85 if size is None:
86 return self.cursor.fetchmany()
87 return self.cursor.fetchmany(size)
88
89 def fetchall(self) -> list[tuple[Any, ...]]:
90 with self.db.wrap_database_errors:
91 return self.cursor.fetchall()
92
93 def __enter__(self) -> Self:
94 return self
95
96 def __exit__(
97 self,
98 type: type[BaseException] | None,
99 value: BaseException | None,
100 traceback: TracebackType | None,
101 ) -> None:
102 # Close instead of passing through to avoid backend-specific behavior
103 # (#17671). Catch errors liberally because errors in cleanup code
104 # aren't useful.
105 try:
106 self.close()
107 except psycopg.Error:
108 pass
109
110 def stream(
111 self, sql: str, params: Sequence[Any] | None = None
112 ) -> Generator[tuple[Any, ...]]:
113 self.db.validate_no_broken_transaction()
114 with db_span(self.db, sql, params=params):
115 with self.db.wrap_database_errors:
116 try:
117 if params is None:
118 yield from self.cursor.stream(sql)
119 else:
120 yield from self.cursor.stream(sql, params)
121 finally:
122 try:
123 self.close()
124 except psycopg.Error:
125 pass
126
127 # The following methods cannot be implemented in __getattr__, because the
128 # code must run when the method is invoked, not just when it is accessed.
129
130 def callproc(
131 self,
132 procname: str,
133 params: Sequence[Any] | None = None,
134 kparams: Mapping[str, Any] | None = None,
135 ) -> Any:
136 # Keyword parameters for callproc aren't supported in PEP 249.
137 # PostgreSQL's psycopg doesn't support them either.
138 if kparams is not None:
139 raise NotSupportedError(
140 "Keyword parameters for callproc are not supported."
141 )
142 self.db.validate_no_broken_transaction()
143 with self.db.wrap_database_errors:
144 if params is None:
145 return self.cursor.callproc(procname)
146 return self.cursor.callproc(procname, params)
147
148 def execute(
149 self, sql: str, params: Sequence[Any] | Mapping[str, Any] | None = None
150 ) -> Self:
151 return self._execute_with_wrappers(
152 sql, params, many=False, executor=self._execute
153 )
154
155 def executemany(self, sql: str, param_list: Sequence[Sequence[Any]]) -> Self:
156 return self._execute_with_wrappers(
157 sql, param_list, many=True, executor=self._executemany
158 )
159
160 def _execute_with_wrappers(
161 self, sql: str, params: Any, many: bool, executor: Any
162 ) -> Self:
163 context: dict[str, Any] = {"connection": self.db, "cursor": self}
164 for wrapper in reversed(self.db.execute_wrappers):
165 executor = functools.partial(wrapper, executor)
166 executor(sql, params, many, context)
167 return self
168
169 def _execute(self, sql: str, params: Any, *ignored_wrapper_args: Any) -> None:
170 # Wrap in an OpenTelemetry span with standard attributes.
171 with db_span(self.db, sql, params=params):
172 self.db.validate_no_broken_transaction()
173 with self.db.wrap_database_errors:
174 if params is None:
175 self.cursor.execute(sql)
176 else:
177 self.cursor.execute(sql, params)
178
179 def _executemany(
180 self, sql: str, param_list: Any, *ignored_wrapper_args: Any
181 ) -> None:
182 with db_span(self.db, sql, many=True, params=param_list):
183 self.db.validate_no_broken_transaction()
184 with self.db.wrap_database_errors:
185 self.cursor.executemany(sql, param_list)
186
187
class CursorDebugWrapper(CursorWrapper):
    """
    CursorWrapper that records each executed statement.

    Each statement is appended to ``db.queries_log`` with its wall-clock
    duration and also emitted as a debug log record.
    """

    # XXX callproc isn't instrumented at this time.

    def stream(
        self, sql: str, params: Sequence[Any] | None = None
    ) -> Generator[tuple[Any, ...]]:
        # The timing covers the whole iteration, including time the consumer
        # spends between rows.
        with self.debug_sql(sql=sql, params=params, use_last_executed_query=True):
            yield from super().stream(sql, params)

    def execute(
        self, sql: str, params: Sequence[Any] | Mapping[str, Any] | None = None
    ) -> Self:
        with self.debug_sql(sql=sql, params=params, use_last_executed_query=True):
            super().execute(sql, params)
        return self

    def executemany(self, sql: str, param_list: Sequence[Sequence[Any]]) -> Self:
        with self.debug_sql(sql=sql, params=param_list, many=True):
            super().executemany(sql, param_list)
        return self

    @contextmanager
    def debug_sql(
        self,
        sql: str | None = None,
        params: Any = None,
        use_last_executed_query: bool = False,
        many: bool = False,
    ) -> Generator[None]:
        """
        Time the enclosed block, then log the statement.

        The log entry is written even if the block raises. When
        ``use_last_executed_query`` is set, the SQL is replaced with
        ``db.last_executed_query(...)`` before logging. For ``many``
        statements, the queries_log entry is prefixed with the number of
        parameter sets, or "?" if it cannot be determined.
        """
        started_at = time.monotonic()
        try:
            yield
        finally:
            elapsed = time.monotonic() - started_at
            if use_last_executed_query:
                sql = self.db.last_executed_query(self.cursor, sql, params)  # type: ignore[arg-type]
            logged_sql = sql
            if many:
                try:
                    repetitions = len(params)
                except TypeError:
                    # params could be an iterator.
                    repetitions = "?"
                logged_sql = f"{repetitions} times: {sql}"
            self.db.queries_log.append({"sql": logged_sql, "time": f"{elapsed:.3f}"})
            logger.debug(
                "(%.3f) %s; args=%s",
                elapsed,
                sql,
                params,
                extra={"duration": elapsed, "sql": sql, "params": params},
            )
247
248
@contextmanager
def debug_transaction(connection: DatabaseConnection, sql: str) -> Generator[None]:
    """
    Time a transaction-control statement such as the one named by ``sql``.

    If ``connection.queries_logged`` is true when the block exits (normally
    or via an exception), append an entry to ``connection.queries_log`` and
    emit a debug log record. Exceptions from the block propagate unchanged.
    """
    started_at = time.monotonic()
    try:
        yield
    finally:
        # Deliberately no early ``return`` here: returning from a
        # ``finally`` would swallow any exception raised in the block.
        if connection.queries_logged:
            elapsed = time.monotonic() - started_at
            connection.queries_log.append({"sql": f"{sql}", "time": f"{elapsed:.3f}"})
            logger.debug(
                "(%.3f) %s; args=%s",
                elapsed,
                sql,
                None,
                extra={"duration": elapsed, "sql": sql},
            )
274
275
def split_tzname_delta(tzname: str) -> tuple[str, str | None, str | None]:
    """
    Split a time zone name into ``(name, sign, offset)``.

    "+" is tried before "-". For each sign present, the text after the last
    occurrence of that sign is the candidate offset. It is accepted only if
    it is non-empty and ``parse_time`` returns a truthy value for it.
    Otherwise ``(tzname, None, None)`` is returned.
    """
    for marker in ("+", "-"):
        if marker not in tzname:
            continue
        head, tail = tzname.rsplit(marker, 1)
        if tail and parse_time(tail):
            return head, marker, tail
    return tzname, None, None
286
287
288###############################################
289# Converters from Python to database (string) #
290###############################################
291
292
def split_identifier(identifier: str) -> tuple[str, str]:
    """
    Split an SQL identifier into ``(namespace, name)``.

    Only an identifier containing exactly one ``"."`` separator (e.g.
    ``"USER"."TABLE"``) is split. Anything else is treated as a bare name
    with an empty namespace. Surrounding double quotes are stripped from
    both parts.
    """
    pieces = identifier.split('"."')
    if len(pieces) == 2:
        namespace, name = pieces
    else:
        namespace, name = "", identifier
    return namespace.strip('"'), name.strip('"')
305
306
def truncate_name(identifier: str, length: int | None = None, hash_len: int = 4) -> str:
    """
    Shorten an SQL identifier to a repeatable, mangled form of ``length`` chars.

    Only the name part (after any ``"namespace".`` prefix) is measured and
    shortened. If ``length`` is None or the quote-stripped name already fits,
    ``identifier`` is returned unchanged. Otherwise the result is the
    quote-stripped namespace (if any) followed by ``"."``, then the first
    ``length - hash_len`` characters of the name, then a ``hash_len``-char
    digest of the full name, e.g. ``USERNAME"."TABLE1a2b``. No outer quotes
    are added.
    """
    namespace, name = split_identifier(identifier)

    if length is not None and len(name) > length:
        prefix = f'{namespace}"."' if namespace else ""
        # NOTE(review): assumes hash_len <= length; otherwise the slice end
        # is negative and the result is longer than ``length``.
        kept = name[: length - hash_len]
        return f"{prefix}{kept}{names_digest(name, length=hash_len)}"

    return identifier
326
327
def names_digest(*args: str, length: int) -> str:
    """
    Return the first ``length`` hex characters of an MD5 digest of ``args``.

    Each hex character carries 4 bits, so ``length=8`` gives a 32-bit
    digest. The arguments are encoded and concatenated without a separator,
    so ``("ab", "c")`` and ``("a", "bc")`` produce the same digest. Changing
    that would presumably alter previously generated names. MD5 is used only
    to shorten names, not for security (``usedforsecurity=False``).
    """
    payload = b"".join(arg.encode() for arg in args)
    return md5(payload, usedforsecurity=False).hexdigest()[:length]
337
338
def strip_quotes(table_name: str) -> str:
    """
    Remove one pair of surrounding double quotes from ``table_name``.

    The result is suitable for building index or sequence names. Names that
    are not both prefixed and suffixed with ``"`` are returned unchanged.
    A lone ``"`` counts as both and becomes ``""``.
    """
    if table_name[:1] == '"' and table_name[-1:] == '"':
        return table_name[1:-1]
    return table_name