1from __future__ import annotations
2
3import functools
4import logging
5import time
6from collections.abc import Generator, Iterator, Mapping, Sequence
7from contextlib import contextmanager
8from hashlib import md5
9from types import TracebackType
10from typing import TYPE_CHECKING, Any, Self
11
12import psycopg
13
14from plain.models.db import NotSupportedError
15from plain.models.otel import db_span
16from plain.utils.dateparse import parse_time
17
18if TYPE_CHECKING:
19 from plain.models.postgres.wrapper import DatabaseWrapper
20
21logger = logging.getLogger("plain.models.postgres")
22
23
24class CursorWrapper:
25 def __init__(self, cursor: Any, db: DatabaseWrapper) -> None:
26 self.cursor = cursor
27 self.db = db
28
29 WRAP_ERROR_ATTRS = frozenset(["nextset"])
30
31 def __getattr__(self, attr: str) -> Any:
32 cursor_attr = getattr(self.cursor, attr)
33 if attr in CursorWrapper.WRAP_ERROR_ATTRS:
34 return self.db.wrap_database_errors(cursor_attr)
35 else:
36 return cursor_attr
37
38 def __iter__(self) -> Iterator[tuple[Any, ...]]:
39 with self.db.wrap_database_errors:
40 yield from self.cursor
41
42 def fetchone(self) -> tuple[Any, ...] | None:
43 with self.db.wrap_database_errors:
44 return self.cursor.fetchone()
45
46 def fetchmany(self, size: int | None = None) -> list[tuple[Any, ...]]:
47 with self.db.wrap_database_errors:
48 if size is None:
49 return self.cursor.fetchmany()
50 return self.cursor.fetchmany(size)
51
52 def fetchall(self) -> list[tuple[Any, ...]]:
53 with self.db.wrap_database_errors:
54 return self.cursor.fetchall()
55
56 def __enter__(self) -> Self:
57 return self
58
59 def __exit__(
60 self,
61 type: type[BaseException] | None,
62 value: BaseException | None,
63 traceback: TracebackType | None,
64 ) -> None:
65 # Close instead of passing through to avoid backend-specific behavior
66 # (#17671). Catch errors liberally because errors in cleanup code
67 # aren't useful.
68 try:
69 self.close()
70 except psycopg.Error:
71 pass
72
73 # The following methods cannot be implemented in __getattr__, because the
74 # code must run when the method is invoked, not just when it is accessed.
75
76 def callproc(
77 self,
78 procname: str,
79 params: Sequence[Any] | None = None,
80 kparams: Mapping[str, Any] | None = None,
81 ) -> Any:
82 # Keyword parameters for callproc aren't supported in PEP 249.
83 # PostgreSQL's psycopg doesn't support them either.
84 if kparams is not None:
85 raise NotSupportedError(
86 "Keyword parameters for callproc are not supported."
87 )
88 self.db.validate_no_broken_transaction()
89 with self.db.wrap_database_errors:
90 if params is None:
91 return self.cursor.callproc(procname)
92 return self.cursor.callproc(procname, params)
93
94 def execute(
95 self, sql: str, params: Sequence[Any] | Mapping[str, Any] | None = None
96 ) -> Self:
97 return self._execute_with_wrappers(
98 sql, params, many=False, executor=self._execute
99 )
100
101 def executemany(self, sql: str, param_list: Sequence[Sequence[Any]]) -> Self:
102 return self._execute_with_wrappers(
103 sql, param_list, many=True, executor=self._executemany
104 )
105
106 def _execute_with_wrappers(
107 self, sql: str, params: Any, many: bool, executor: Any
108 ) -> Self:
109 context: dict[str, Any] = {"connection": self.db, "cursor": self}
110 for wrapper in reversed(self.db.execute_wrappers):
111 executor = functools.partial(wrapper, executor)
112 executor(sql, params, many, context)
113 return self
114
115 def _execute(self, sql: str, params: Any, *ignored_wrapper_args: Any) -> None:
116 # Wrap in an OpenTelemetry span with standard attributes.
117 with db_span(self.db, sql, params=params):
118 self.db.validate_no_broken_transaction()
119 with self.db.wrap_database_errors:
120 if params is None:
121 self.cursor.execute(sql)
122 else:
123 self.cursor.execute(sql, params)
124
125 def _executemany(
126 self, sql: str, param_list: Any, *ignored_wrapper_args: Any
127 ) -> None:
128 with db_span(self.db, sql, many=True, params=param_list):
129 self.db.validate_no_broken_transaction()
130 with self.db.wrap_database_errors:
131 self.cursor.executemany(sql, param_list)
132
133
class CursorDebugWrapper(CursorWrapper):
    """
    Cursor wrapper that times every execute()/executemany() call, records it
    in ``db.queries_log``, and emits a debug log record for it.
    """

    # XXX callproc isn't instrumented at this time.

    def execute(
        self, sql: str, params: Sequence[Any] | Mapping[str, Any] | None = None
    ) -> Self:
        """Execute ``sql`` while recording it, then return this wrapper."""
        recorder = self.debug_sql(sql, params, use_last_executed_query=True)
        with recorder:
            super().execute(sql, params)
        return self

    def executemany(self, sql: str, param_list: Sequence[Sequence[Any]]) -> Self:
        """Execute ``sql`` for each parameter set while recording it."""
        recorder = self.debug_sql(sql, param_list, many=True)
        with recorder:
            super().executemany(sql, param_list)
        return self

    @contextmanager
    def debug_sql(
        self,
        sql: str | None = None,
        params: Any = None,
        use_last_executed_query: bool = False,
        many: bool = False,
    ) -> Generator[None, None, None]:
        """
        Context manager that measures the wrapped block and records it.

        On exit (whether or not the block raised) an entry with ``sql`` and
        ``time`` keys is appended to ``db.queries_log`` and a debug message is
        logged. With ``use_last_executed_query`` the SQL text is taken from
        ``db.last_executed_query()``; with ``many`` the logged SQL is prefixed
        with the number of parameter sets.
        """
        began = time.monotonic()
        try:
            yield
        finally:
            duration = time.monotonic() - began
            if use_last_executed_query:
                sql = self.db.last_executed_query(self.cursor, sql, params)  # type: ignore[arg-type]
            if many:
                try:
                    times = len(params)
                except TypeError:
                    # params could be an iterator.
                    times = "?"
                recorded_sql = f"{times} times: {sql}"
            else:
                recorded_sql = sql
            entry = {"sql": recorded_sql, "time": f"{duration:.3f}"}
            self.db.queries_log.append(entry)
            log_extra = {"duration": duration, "sql": sql, "params": params}
            logger.debug(
                "(%.3f) %s; args=%s", duration, sql, params, extra=log_extra
            )
187
188
@contextmanager
def debug_transaction(
    connection: DatabaseWrapper, sql: str
) -> Generator[None, None, None]:
    """
    Context manager that records a transaction-level statement.

    When ``connection.queries_logged`` is true on exit, the elapsed time and
    ``sql`` are appended to ``connection.queries_log`` and a debug message is
    logged. Nothing is recorded otherwise. Exceptions raised inside the block
    propagate unchanged.
    """
    began = time.monotonic()
    try:
        yield
    finally:
        if connection.queries_logged:
            duration = time.monotonic() - began
            entry = {"sql": f"{sql}", "time": f"{duration:.3f}"}
            connection.queries_log.append(entry)
            log_extra = {"duration": duration, "sql": sql}
            logger.debug(
                "(%.3f) %s; args=%s", duration, sql, None, extra=log_extra
            )
216
217
def split_tzname_delta(tzname: str) -> tuple[str, str | None, str | None]:
    """
    Break a time zone name into its (name, sign, offset) parts.

    "+" is tried before "-". The name is split at the last occurrence of the
    sign, and the split is accepted only when the trailing part is non-empty
    and parse_time() returns a truthy value for it. If neither sign yields
    such an offset, (tzname, None, None) is returned.
    """
    for sign in "+-":
        name, found, offset = tzname.rpartition(sign)
        # ``found`` is empty when the sign does not occur in tzname at all.
        if found and offset and parse_time(offset):
            return name, sign, offset
    return tzname, None, None
228
229
230###############################################
231# Converters from Python to database (string) #
232###############################################
233
234
def split_identifier(identifier: str) -> tuple[str, str]:
    """
    Split an SQL identifier into a two-element tuple of (namespace, name).

    The identifier could be a table, column, or sequence name, and it might
    be prefixed by a namespace, separated from the name by '"."'. Surrounding
    double quotes are stripped from both parts. When the identifier does not
    split into exactly two pieces, the namespace is "" and the whole
    identifier is treated as the name.
    """
    pieces = identifier.split('"."')
    if len(pieces) == 2:
        namespace, name = pieces
    else:
        namespace, name = "", identifier
    return namespace.strip('"'), name.strip('"')
247
248
def truncate_name(identifier: str, length: int | None = None, hash_len: int = 4) -> str:
    """
    Shorten an SQL identifier to a repeatable mangled version with the given
    length.

    The identifier is returned unchanged when ``length`` is None or the name
    part already fits. Otherwise the name is cut to ``length - hash_len``
    characters and a ``hash_len``-character digest of the full name is
    appended.

    If a quote stripped name contains a namespace, e.g. USERNAME"."TABLE,
    only the table portion is truncated and the namespace is kept in front.
    """
    namespace, name = split_identifier(identifier)

    if length is None or len(name) <= length:
        return identifier

    prefix = f'{namespace}"."' if namespace else ""
    kept = name[: length - hash_len]
    suffix = names_digest(name, length=hash_len)
    return f"{prefix}{kept}{suffix}"
268
269
def names_digest(*args: str, length: int) -> str:
    """
    Generate a short digest of a set of arguments that can be used to shorten
    identifying names.

    The arguments are encoded and hashed in order with MD5 (not used for
    security), and the first ``length`` characters of the hex digest are
    returned.
    """
    payload = b"".join(arg.encode() for arg in args)
    return md5(payload, usedforsecurity=False).hexdigest()[:length]
279
280
def strip_quotes(table_name: str) -> str:
    """
    Remove one pair of enclosing double quotes from a quoted table name so it
    is safe for use in index names, sequence names, etc.

    Names that do not both start and end with a double quote are returned
    unchanged.
    """
    if table_name[:1] == '"' and table_name[-1:] == '"':
        return table_name[1:-1]
    return table_name