1from __future__ import annotations
2
3import _thread
4import datetime
5import logging
6import os
7import signal
8import subprocess
9import sys
10import threading
11import time
12import warnings
13import zoneinfo
14from collections import deque
15from collections.abc import Generator, Sequence
16from contextlib import contextmanager
17from functools import cached_property, lru_cache
18from typing import TYPE_CHECKING, Any, LiteralString, NamedTuple, cast
19
20import psycopg as Database
21from psycopg import ClientCursor, IsolationLevel, adapt, adapters, errors
22from psycopg import sql as psycopg_sql
23from psycopg.abc import Buffer, PyFormat
24from psycopg.postgres import types as pg_types
25from psycopg.pq import Format
26from psycopg.types.datetime import TimestamptzLoader
27from psycopg.types.range import BaseRangeDumper, Range, RangeDumper
28from psycopg.types.string import TextLoader
29
30from plain.exceptions import ImproperlyConfigured
31from plain.models.db import (
32 DatabaseError,
33 DatabaseErrorWrapper,
34 NotSupportedError,
35 db_connection,
36)
37from plain.models.db import DatabaseError as WrappedDatabaseError
38from plain.models.indexes import Index
39from plain.models.postgres import utils
40from plain.models.postgres.schema import DatabaseSchemaEditor
41from plain.models.postgres.sql import MAX_NAME_LENGTH, quote_name
42from plain.models.postgres.utils import CursorDebugWrapper as BaseCursorDebugWrapper
43from plain.models.postgres.utils import CursorWrapper, debug_transaction
44from plain.models.transaction import TransactionManagementError
45from plain.runtime import settings
46
47if TYPE_CHECKING:
48 from psycopg import Connection as PsycopgConnection
49
50 from plain.models.connections import DatabaseConfig
51 from plain.models.fields import Field
52
# Module-level flag flipped to True by DatabaseWrapper.init_connection_state()
# after check_database_version_supported() has run, so later connections in
# this process skip the version check.
RAN_DB_VERSION_CHECK = False

# Logger used by this backend (e.g. for failing robust on_commit() callbacks).
logger = logging.getLogger("plain.models.postgres")

# The prefix to put on the default database name when creating
# the test database.
TEST_DATABASE_PREFIX = "test_"
60
61
def get_migratable_models() -> Generator[Any, None, None]:
    """
    Yield every model that migrations should consider.

    Models are produced package by package, in the order returned by
    ``packages_registry.get_package_configs()``. The package list is fetched
    immediately; each package's models are looked up lazily while iterating.
    """
    from plain.models import models_registry
    from plain.packages import packages_registry

    package_configs = packages_registry.get_package_configs()
    return (
        model
        for config in package_configs
        for model in models_registry.get_models(package_label=config.package_label)
    )
74
75
class TableInfo(NamedTuple):
    """Structure returned by DatabaseWrapper.get_table_list()."""

    # Name of the table.
    name: str
    # Relation type. NOTE(review): presumably a short code such as "t" (table)
    # or "v" (view) — confirm against get_table_list().
    type: str
    # Comment attached to the table; presumably None when no comment is set.
    comment: str | None
82
83
# Type OIDs, resolved once at import time from psycopg's type registries.
# TIMESTAMPTZ_OID lets DatabaseWrapper.create_cursor() look up the active
# timestamptz loader; the range OIDs are used by PlainRangeDumper.upgrade().
TIMESTAMPTZ_OID = adapters.types["timestamptz"].oid
TSRANGE_OID = pg_types["tsrange"].oid
TSTZRANGE_OID = pg_types["tstzrange"].oid
88
89
class BaseTzLoader(TimestamptzLoader):
    """
    timestamptz loader that stamps loaded values with a fixed ``timezone``.

    The parent loader's result has its tzinfo *replaced* (not converted) with
    the class-level ``timezone``. When ``timezone`` is None the result comes
    back as a naive datetime.
    """

    timezone: datetime.tzinfo | None = None

    def load(self, data: Buffer) -> datetime.datetime:
        return super().load(data).replace(tzinfo=self.timezone)
101
102
def register_tzloader(tz: datetime.tzinfo | None, context: Any) -> None:
    """
    Register on ``context`` a timestamptz loader bound to the timezone ``tz``.

    ``context`` is anything exposing an ``adapters`` map (an AdaptersMap,
    connection, or cursor). A fresh BaseTzLoader subclass is created per call
    so that each registration keeps its own timezone.
    """

    class SpecificTzLoader(BaseTzLoader):
        timezone = tz

    adapters_map = context.adapters
    adapters_map.register_loader("timestamptz", SpecificTzLoader)
108
109
class PlainRangeDumper(RangeDumper):
    """
    Range dumper that sends tsrange values as tstzrange.

    NOTE(review): presumably this is because Plain works with aware datetimes,
    so timestamp ranges should be timezone-aware on the server — confirm.
    """

    def upgrade(self, obj: Range[Any], format: PyFormat) -> BaseRangeDumper:
        upgraded = super().upgrade(obj, format)
        # Leave self, and any dumper whose oid is not tsrange, untouched.
        if upgraded is self or upgraded.oid != TSRANGE_OID:
            return upgraded
        upgraded.oid = TSTZRANGE_OID
        return upgraded
118
119
@lru_cache
def get_adapters_template(timezone: datetime.tzinfo | None) -> adapt.AdaptersMap:
    """
    Build (and memoize per ``timezone``) the adapters map used for connections.

    The map is derived from psycopg's global ``adapters`` with these changes:
    jsonb, inet and cidr values are loaded as plain text (no-op JSON loading
    avoids psycopg3 round trips); Range values use PlainRangeDumper; and
    timestamptz values are loaded in ``timezone``.
    """
    template = adapt.AdaptersMap(adapters)
    for type_name in ("jsonb", "inet", "cidr"):
        template.register_loader(type_name, TextLoader)
    template.register_dumper(Range, PlainRangeDumper)
    register_tzloader(timezone, template)
    return template
131
132
def _psql_settings_to_cmd_args_env(
    settings_dict: DatabaseConfig, parameters: list[str]
) -> tuple[list[str], dict[str, str] | None]:
    """
    Translate database settings into a ``psql`` argv and extra environment.

    USER, HOST and PORT become ``-U``/``-h``/``-p`` flags. Caller-supplied
    ``parameters`` come next, followed by the database name. The database name
    falls back to 'postgres' when neither NAME nor OPTIONS['service'] is set.
    Password, service, SSL and passfile settings are passed through ``PG*``
    environment variables.

    Returns:
        ``(args, env)`` where ``env`` is None when no variable needs setting.
    """
    options = settings_dict.get("OPTIONS", {})
    service = options.get("service")
    dbname = settings_dict.get("NAME")
    if not (dbname or service):
        # Connect to the default 'postgres' db.
        dbname = "postgres"

    user = settings_dict.get("USER")
    host = settings_dict.get("HOST")
    port = settings_dict.get("PORT")

    args = ["psql"]
    if user:
        args.extend(("-U", user))
    if host:
        args.extend(("-h", host))
    if port:
        args.extend(("-p", str(port)))
    args.extend(parameters)
    if dbname:
        args.append(dbname)

    env_sources = (
        ("PGPASSWORD", settings_dict.get("PASSWORD")),
        ("PGSERVICE", service),
        ("PGSSLMODE", options.get("sslmode")),
        ("PGSSLROOTCERT", options.get("sslrootcert")),
        ("PGSSLCERT", options.get("sslcert")),
        ("PGSSLKEY", options.get("sslkey")),
        ("PGPASSFILE", options.get("passfile")),
    )
    env = {var: str(value) for var, value in env_sources if value}
    return args, env or None
181
182
183class DatabaseWrapper:
184 """
185 PostgreSQL database connection wrapper.
186
187 This is the only database backend supported by Plain.
188 """
189
190 queries_limit: int = 9000
191 executable_name: str = "psql"
192
193 index_default_access_method = "btree"
194 ignored_tables: list[str] = []
195
196 # PostgreSQL backend-specific attributes.
197 _named_cursor_idx = 0
198
199 def __init__(self, settings_dict: DatabaseConfig):
200 # Connection related attributes.
201 # The underlying database connection (from the database library, not a wrapper).
202 self.connection: PsycopgConnection[Any] | None = None
203 # `settings_dict` should be a dictionary containing keys such as
204 # NAME, USER, etc. It's called `settings_dict` instead of `settings`
205 # to disambiguate it from Plain settings modules.
206 self.settings_dict: DatabaseConfig = settings_dict
207 # Query logging in debug mode or when explicitly enabled.
208 self.queries_log: deque[dict[str, Any]] = deque(maxlen=self.queries_limit)
209 self.force_debug_cursor: bool = False
210
211 # Transaction related attributes.
212 # Tracks if the connection is in autocommit mode. Per PEP 249, by
213 # default, it isn't.
214 self.autocommit: bool = False
215 # Tracks if the connection is in a transaction managed by 'atomic'.
216 self.in_atomic_block: bool = False
217 # Increment to generate unique savepoint ids.
218 self.savepoint_state: int = 0
219 # List of savepoints created by 'atomic'.
220 self.savepoint_ids: list[str] = []
221 # Stack of active 'atomic' blocks.
222 self.atomic_blocks: list[Any] = []
223 # Tracks if the outermost 'atomic' block should commit on exit,
224 # ie. if autocommit was active on entry.
225 self.commit_on_exit: bool = True
226 # Tracks if the transaction should be rolled back to the next
227 # available savepoint because of an exception in an inner block.
228 self.needs_rollback: bool = False
229 self.rollback_exc: Exception | None = None
230
231 # Connection termination related attributes.
232 self.close_at: float | None = None
233 self.closed_in_transaction: bool = False
234 self.errors_occurred: bool = False
235 self.health_check_enabled: bool = False
236 self.health_check_done: bool = False
237
238 # Thread-safety related attributes.
239 self._thread_sharing_lock: threading.Lock = threading.Lock()
240 self._thread_sharing_count: int = 0
241 self._thread_ident: int = _thread.get_ident()
242
243 # A list of no-argument functions to run when the transaction commits.
244 # Each entry is an (sids, func, robust) tuple, where sids is a set of
245 # the active savepoint IDs when this function was registered and robust
246 # specifies whether it's allowed for the function to fail.
247 self.run_on_commit: list[tuple[set[str], Any, bool]] = []
248
249 # Should we run the on-commit hooks the next time set_autocommit(True)
250 # is called?
251 self.run_commit_hooks_on_set_autocommit_on: bool = False
252
253 # A stack of wrappers to be invoked around execute()/executemany()
254 # calls. Each entry is a function taking five arguments: execute, sql,
255 # params, many, and context. It's the function's responsibility to
256 # call execute(sql, params, many, context).
257 self.execute_wrappers: list[Any] = []
258
259 def __repr__(self) -> str:
260 return f"<{self.__class__.__qualname__} vendor='postgresql'>"
261
262 @cached_property
263 def timezone(self) -> datetime.tzinfo:
264 """
265 Return a tzinfo of the database connection time zone.
266
267 When a datetime is read from the database, it is returned in this time
268 zone. Since PostgreSQL supports time zones, it doesn't matter which
269 time zone Plain uses, as long as aware datetimes are used everywhere.
270 Other users connecting to the database can choose their own time zone.
271 """
272 if self.settings_dict["TIME_ZONE"] is None:
273 return datetime.UTC
274 else:
275 return zoneinfo.ZoneInfo(self.settings_dict["TIME_ZONE"])
276
277 @cached_property
278 def timezone_name(self) -> str:
279 """
280 Name of the time zone of the database connection.
281 """
282 if self.settings_dict["TIME_ZONE"] is None:
283 return "UTC"
284 else:
285 return self.settings_dict["TIME_ZONE"]
286
287 @property
288 def queries_logged(self) -> bool:
289 return self.force_debug_cursor or settings.DEBUG
290
291 @property
292 def queries(self) -> list[dict[str, Any]]:
293 if len(self.queries_log) == self.queries_log.maxlen:
294 warnings.warn(
295 f"Limit for query logging exceeded, only the last {self.queries_log.maxlen} queries "
296 "will be returned."
297 )
298 return list(self.queries_log)
299
300 def check_database_version_supported(self) -> None:
301 """
302 Raise an error if the database version isn't supported by this
303 version of Plain. PostgreSQL 12+ is required.
304 """
305 major, minor = divmod(self.pg_version, 10000)
306 if major < 12:
307 raise NotSupportedError(
308 f"PostgreSQL 12 or later is required (found {major}.{minor})."
309 )
310
311 # ##### Connection and cursor methods #####
312
313 def get_connection_params(self) -> dict[str, Any]:
314 """Return a dict of parameters suitable for get_new_connection."""
315 settings_dict = self.settings_dict
316 options = settings_dict.get("OPTIONS", {})
317 # None may be used to connect to the default 'postgres' db
318 if settings_dict.get("NAME") == "" and not options.get("service"):
319 raise ImproperlyConfigured(
320 "settings.DATABASE is improperly configured. "
321 "Please supply the NAME or OPTIONS['service'] value."
322 )
323 db_name = settings_dict.get("NAME")
324 if len(db_name or "") > MAX_NAME_LENGTH:
325 raise ImproperlyConfigured(
326 "The database name '%s' (%d characters) is longer than " # noqa: UP031
327 "PostgreSQL's limit of %d characters. Supply a shorter NAME "
328 "in settings.DATABASE."
329 % (
330 db_name,
331 len(db_name or ""),
332 MAX_NAME_LENGTH,
333 )
334 )
335 conn_params: dict[str, Any] = {"client_encoding": "UTF8"}
336 if db_name:
337 conn_params = {
338 "dbname": db_name,
339 **options,
340 }
341 elif db_name is None:
342 # Connect to the default 'postgres' db.
343 options.pop("service", None)
344 conn_params = {"dbname": "postgres", **options}
345 else:
346 conn_params = {**options}
347
348 conn_params.pop("assume_role", None)
349 conn_params.pop("isolation_level", None)
350 conn_params.pop("server_side_binding", None)
351 if settings_dict["USER"]:
352 conn_params["user"] = settings_dict["USER"]
353 if settings_dict["PASSWORD"]:
354 conn_params["password"] = settings_dict["PASSWORD"]
355 if settings_dict["HOST"]:
356 conn_params["host"] = settings_dict["HOST"]
357 if settings_dict["PORT"]:
358 conn_params["port"] = settings_dict["PORT"]
359 conn_params["context"] = get_adapters_template(self.timezone)
360 # Disable prepared statements by default to keep connection poolers
361 # working. Can be reenabled via OPTIONS in the settings dict.
362 conn_params["prepare_threshold"] = conn_params.pop("prepare_threshold", None)
363 return conn_params
364
365 def get_new_connection(self, conn_params: dict[str, Any]) -> PsycopgConnection[Any]:
366 """Open a connection to the database."""
367 # self.isolation_level must be set:
368 # - after connecting to the database in order to obtain the database's
369 # default when no value is explicitly specified in options.
370 # - before calling _set_autocommit() because if autocommit is on, that
371 # will set connection.isolation_level to ISOLATION_LEVEL_AUTOCOMMIT.
372 options = self.settings_dict.get("OPTIONS", {})
373 set_isolation_level = False
374 try:
375 isolation_level_value = options["isolation_level"]
376 except KeyError:
377 self.isolation_level = IsolationLevel.READ_COMMITTED
378 else:
379 # Set the isolation level to the value from OPTIONS.
380 try:
381 self.isolation_level = IsolationLevel(isolation_level_value)
382 set_isolation_level = True
383 except ValueError:
384 raise ImproperlyConfigured(
385 f"Invalid transaction isolation level {isolation_level_value} "
386 f"specified. Use one of the psycopg.IsolationLevel values."
387 )
388 connection = Database.connect(**conn_params)
389 if set_isolation_level:
390 connection.isolation_level = self.isolation_level
391 # Use server-side binding cursor if requested, otherwise standard cursor
392 connection.cursor_factory = (
393 ServerBindingCursor
394 if options.get("server_side_binding") is True
395 else Cursor
396 )
397 return connection
398
399 def ensure_timezone(self) -> bool:
400 """
401 Ensure the connection's timezone is set to `self.timezone_name` and
402 return whether it changed or not.
403 """
404 if self.connection is None:
405 return False
406 conn_timezone_name = self.connection.info.parameter_status("TimeZone")
407 timezone_name = self.timezone_name
408 if timezone_name and conn_timezone_name != timezone_name:
409 with self.connection.cursor() as cursor:
410 cursor.execute(
411 "SELECT set_config('TimeZone', %s, false)", [timezone_name]
412 )
413 return True
414 return False
415
416 def ensure_role(self) -> bool:
417 if self.connection is None:
418 return False
419 if new_role := self.settings_dict.get("OPTIONS", {}).get("assume_role"):
420 with self.connection.cursor() as cursor:
421 sql_str = self.compose_sql("SET ROLE %s", [new_role])
422 cursor.execute(sql_str) # type: ignore[arg-type]
423 return True
424 return False
425
426 def init_connection_state(self) -> None:
427 """Initialize the database connection settings."""
428 global RAN_DB_VERSION_CHECK
429 if not RAN_DB_VERSION_CHECK:
430 self.check_database_version_supported()
431 RAN_DB_VERSION_CHECK = True
432
433 # Commit after setting the time zone.
434 commit_tz = self.ensure_timezone()
435 # Set the role on the connection. This is useful if the credential used
436 # to login is not the same as the role that owns database resources. As
437 # can be the case when using temporary or ephemeral credentials.
438 commit_role = self.ensure_role()
439
440 if (commit_role or commit_tz) and not self.get_autocommit():
441 assert self.connection is not None
442 self.connection.commit()
443
444 def create_cursor(self, name: str | None = None) -> Any:
445 """Create a cursor. Assume that a connection is established."""
446 assert self.connection is not None
447 if name:
448 # In autocommit mode, the cursor will be used outside of a
449 # transaction, hence use a holdable cursor.
450 cursor = self.connection.cursor(
451 name, scrollable=False, withhold=self.connection.autocommit
452 )
453 else:
454 cursor = self.connection.cursor()
455
456 # Register the cursor timezone only if the connection disagrees, to avoid copying the adapter map.
457 tzloader = self.connection.adapters.get_loader(TIMESTAMPTZ_OID, Format.TEXT)
458 if self.timezone != tzloader.timezone: # type: ignore[union-attr]
459 register_tzloader(self.timezone, cursor)
460 return cursor
461
462 def chunked_cursor(self) -> utils.CursorWrapper:
463 """
464 Return a server-side cursor that avoids caching results in memory.
465 """
466 self._named_cursor_idx += 1
467 # Get the current async task
468 # Note that right now this is behind @async_unsafe, so this is
469 # unreachable, but in future we'll start loosening this restriction.
470 # For now, it's here so that every use of "threading" is
471 # also async-compatible.
472 task_ident = "sync"
473 # Use that and the thread ident to get a unique name
474 return self._cursor(
475 name="_plain_curs_%d_%s_%d" # noqa: UP031
476 % (
477 # Avoid reusing name in other threads / tasks
478 threading.current_thread().ident,
479 task_ident,
480 self._named_cursor_idx,
481 )
482 )
483
484 def _set_autocommit(self, autocommit: bool) -> None:
485 """Backend-specific implementation to enable or disable autocommit."""
486 assert self.connection is not None
487 with self.wrap_database_errors:
488 self.connection.autocommit = autocommit
489
490 def check_constraints(self, table_names: list[str] | None = None) -> None:
491 """
492 Check constraints by setting them to immediate. Return them to deferred
493 afterward.
494 """
495 with self.cursor() as cursor:
496 cursor.execute("SET CONSTRAINTS ALL IMMEDIATE")
497 cursor.execute("SET CONSTRAINTS ALL DEFERRED")
498
499 def is_usable(self) -> bool:
500 """
501 Test if the database connection is usable.
502
503 This method may assume that self.connection is not None.
504
505 Actual implementations should take care not to raise exceptions
506 as that may prevent Plain from recycling unusable connections.
507 """
508 assert self.connection is not None
509 try:
510 # Use a psycopg cursor directly, bypassing Plain's utilities.
511 with self.connection.cursor() as cursor:
512 cursor.execute("SELECT 1")
513 except Database.Error:
514 return False
515 else:
516 return True
517
    @contextmanager
    def _nodb_cursor(self) -> Generator[utils.CursorWrapper, None, None]:
        """
        Return a cursor from an alternative connection to be used when there is
        no need to access the main database, specifically for test db
        creation/deletion. This also prevents the production database from
        being exposed to potential child threads while (or after) the test
        database is destroyed. Refs #10868, #17786, #16969.

        The first attempt uses a copy of these settings with NAME=None, which
        get_connection_params() maps to the 'postgres' database. If that
        connection cannot be established, a RuntimeWarning is emitted and a
        second wrapper is built using db_connection's NAME instead. Errors
        raised after the cursor has been yielded to the caller propagate and
        do not trigger the fallback. Each temporary wrapper is closed on exit.

        NOTE: ``{**self.settings_dict, ...}`` is a shallow copy, so nested
        values such as the OPTIONS dict are shared with this wrapper.
        """
        cursor = None
        try:
            conn = self.__class__({**self.settings_dict, "NAME": None})
            try:
                # `cursor` is only bound once conn.cursor() (which connects)
                # has succeeded.
                with conn.cursor() as cursor:
                    yield cursor
            finally:
                conn.close()
        except (Database.DatabaseError, WrappedDatabaseError):
            # A bound cursor means the connection worked and the error came
            # from the caller's block (or later): re-raise, don't retry.
            if cursor is not None:
                raise
            warnings.warn(
                "Normally Plain will use a connection to the 'postgres' database "
                "to avoid running initialization queries against the production "
                "database when it's not needed (for example, when running tests). "
                "Plain was unable to create a connection to the 'postgres' database "
                "and will use the first PostgreSQL database instead.",
                RuntimeWarning,
            )
            # Fallback: connect to the database named in db_connection's settings.
            conn = self.__class__(
                {
                    **self.settings_dict,
                    "NAME": db_connection.settings_dict["NAME"],
                },
            )
            try:
                with conn.cursor() as cursor:
                    yield cursor
            finally:
                conn.close()
557
558 @cached_property
559 def pg_version(self) -> int:
560 with self.temporary_connection():
561 assert self.connection is not None
562 return self.connection.info.server_version
563
564 def make_debug_cursor(self, cursor: Any) -> CursorDebugWrapper:
565 return CursorDebugWrapper(cursor, self)
566
567 # ##### Connection lifecycle #####
568
569 def connect(self) -> None:
570 """Connect to the database. Assume that the connection is closed."""
571 # In case the previous connection was closed while in an atomic block
572 self.in_atomic_block = False
573 self.savepoint_ids = []
574 self.atomic_blocks = []
575 self.needs_rollback = False
576 # Reset parameters defining when to close/health-check the connection.
577 self.health_check_enabled = self.settings_dict["CONN_HEALTH_CHECKS"]
578 max_age = self.settings_dict["CONN_MAX_AGE"]
579 self.close_at = None if max_age is None else time.monotonic() + max_age
580 self.closed_in_transaction = False
581 self.errors_occurred = False
582 # New connections are healthy.
583 self.health_check_done = True
584 # Establish the connection
585 conn_params = self.get_connection_params()
586 self.connection = self.get_new_connection(conn_params)
587 self.set_autocommit(self.settings_dict["AUTOCOMMIT"])
588 self.init_connection_state()
589
590 self.run_on_commit = []
591
592 def ensure_connection(self) -> None:
593 """Guarantee that a connection to the database is established."""
594 if self.connection is None:
595 with self.wrap_database_errors:
596 self.connect()
597
598 # ##### PEP-249 connection method wrappers #####
599
600 def _prepare_cursor(self, cursor: Any) -> utils.CursorWrapper:
601 """
602 Validate the connection is usable and perform database cursor wrapping.
603 """
604 self.validate_thread_sharing()
605 if self.queries_logged:
606 wrapped_cursor = self.make_debug_cursor(cursor)
607 else:
608 wrapped_cursor = self.make_cursor(cursor)
609 return wrapped_cursor
610
611 def _cursor(self, name: str | None = None) -> utils.CursorWrapper:
612 self.close_if_health_check_failed()
613 self.ensure_connection()
614 with self.wrap_database_errors:
615 return self._prepare_cursor(self.create_cursor(name))
616
617 def _commit(self) -> None:
618 if self.connection is not None:
619 with debug_transaction(self, "COMMIT"), self.wrap_database_errors:
620 return self.connection.commit()
621
622 def _rollback(self) -> None:
623 if self.connection is not None:
624 with debug_transaction(self, "ROLLBACK"), self.wrap_database_errors:
625 return self.connection.rollback()
626
627 def _close(self) -> None:
628 if self.connection is not None:
629 with self.wrap_database_errors:
630 return self.connection.close()
631
632 # ##### Generic wrappers for PEP-249 connection methods #####
633
634 def cursor(self) -> utils.CursorWrapper:
635 """Create a cursor, opening a connection if necessary."""
636 return self._cursor()
637
638 def commit(self) -> None:
639 """Commit a transaction and reset the dirty flag."""
640 self.validate_thread_sharing()
641 self.validate_no_atomic_block()
642 self._commit()
643 # A successful commit means that the database connection works.
644 self.errors_occurred = False
645 self.run_commit_hooks_on_set_autocommit_on = True
646
647 def rollback(self) -> None:
648 """Roll back a transaction and reset the dirty flag."""
649 self.validate_thread_sharing()
650 self.validate_no_atomic_block()
651 self._rollback()
652 # A successful rollback means that the database connection works.
653 self.errors_occurred = False
654 self.needs_rollback = False
655 self.run_on_commit = []
656
657 def close(self) -> None:
658 """Close the connection to the database."""
659 self.validate_thread_sharing()
660 self.run_on_commit = []
661
662 # Don't call validate_no_atomic_block() to avoid making it difficult
663 # to get rid of a connection in an invalid state. The next connect()
664 # will reset the transaction state anyway.
665 if self.closed_in_transaction or self.connection is None:
666 return
667 try:
668 self._close()
669 finally:
670 if self.in_atomic_block:
671 self.closed_in_transaction = True
672 self.needs_rollback = True
673 else:
674 self.connection = None
675
676 # ##### Savepoint management #####
677
678 def _savepoint(self, sid: str) -> None:
679 with self.cursor() as cursor:
680 cursor.execute(f"SAVEPOINT {quote_name(sid)}")
681
682 def _savepoint_rollback(self, sid: str) -> None:
683 with self.cursor() as cursor:
684 cursor.execute(f"ROLLBACK TO SAVEPOINT {quote_name(sid)}")
685
686 def _savepoint_commit(self, sid: str) -> None:
687 with self.cursor() as cursor:
688 cursor.execute(f"RELEASE SAVEPOINT {quote_name(sid)}")
689
690 # ##### Generic savepoint management methods #####
691
692 def savepoint(self) -> str | None:
693 """
694 Create a savepoint inside the current transaction. Return an
695 identifier for the savepoint that will be used for the subsequent
696 rollback or commit. Return None if in autocommit mode (no transaction).
697 """
698 if self.get_autocommit():
699 return None
700
701 thread_ident = _thread.get_ident()
702 tid = str(thread_ident).replace("-", "")
703
704 self.savepoint_state += 1
705 sid = "s%s_x%d" % (tid, self.savepoint_state) # noqa: UP031
706
707 self.validate_thread_sharing()
708 self._savepoint(sid)
709
710 return sid
711
712 def savepoint_rollback(self, sid: str) -> None:
713 """
714 Roll back to a savepoint. Do nothing if in autocommit mode.
715 """
716 if self.get_autocommit():
717 return
718
719 self.validate_thread_sharing()
720 self._savepoint_rollback(sid)
721
722 # Remove any callbacks registered while this savepoint was active.
723 self.run_on_commit = [
724 (sids, func, robust)
725 for (sids, func, robust) in self.run_on_commit
726 if sid not in sids
727 ]
728
729 def savepoint_commit(self, sid: str) -> None:
730 """
731 Release a savepoint. Do nothing if in autocommit mode.
732 """
733 if self.get_autocommit():
734 return
735
736 self.validate_thread_sharing()
737 self._savepoint_commit(sid)
738
739 def clean_savepoints(self) -> None:
740 """
741 Reset the counter used to generate unique savepoint ids in this thread.
742 """
743 self.savepoint_state = 0
744
745 # ##### Generic transaction management methods #####
746
747 def get_autocommit(self) -> bool:
748 """Get the autocommit state."""
749 self.ensure_connection()
750 return self.autocommit
751
752 def set_autocommit(self, autocommit: bool) -> None:
753 """Enable or disable autocommit."""
754 self.validate_no_atomic_block()
755 self.close_if_health_check_failed()
756 self.ensure_connection()
757
758 if autocommit:
759 self._set_autocommit(autocommit)
760 else:
761 with debug_transaction(self, "BEGIN"):
762 self._set_autocommit(autocommit)
763 self.autocommit = autocommit
764
765 if autocommit and self.run_commit_hooks_on_set_autocommit_on:
766 self.run_and_clear_commit_hooks()
767 self.run_commit_hooks_on_set_autocommit_on = False
768
769 def get_rollback(self) -> bool:
770 """Get the "needs rollback" flag -- for *advanced use* only."""
771 if not self.in_atomic_block:
772 raise TransactionManagementError(
773 "The rollback flag doesn't work outside of an 'atomic' block."
774 )
775 return self.needs_rollback
776
777 def set_rollback(self, rollback: bool) -> None:
778 """
779 Set or unset the "needs rollback" flag -- for *advanced use* only.
780 """
781 if not self.in_atomic_block:
782 raise TransactionManagementError(
783 "The rollback flag doesn't work outside of an 'atomic' block."
784 )
785 self.needs_rollback = rollback
786
787 def validate_no_atomic_block(self) -> None:
788 """Raise an error if an atomic block is active."""
789 if self.in_atomic_block:
790 raise TransactionManagementError(
791 "This is forbidden when an 'atomic' block is active."
792 )
793
794 def validate_no_broken_transaction(self) -> None:
795 if self.needs_rollback:
796 raise TransactionManagementError(
797 "An error occurred in the current transaction. You can't "
798 "execute queries until the end of the 'atomic' block."
799 ) from self.rollback_exc
800
801 # ##### Connection termination handling #####
802
803 def close_if_health_check_failed(self) -> None:
804 """Close existing connection if it fails a health check."""
805 if (
806 self.connection is None
807 or not self.health_check_enabled
808 or self.health_check_done
809 ):
810 return
811
812 if not self.is_usable():
813 self.close()
814 self.health_check_done = True
815
816 def close_if_unusable_or_obsolete(self) -> None:
817 """
818 Close the current connection if unrecoverable errors have occurred
819 or if it outlived its maximum age.
820 """
821 if self.connection is not None:
822 self.health_check_done = False
823 # If the application didn't restore the original autocommit setting,
824 # don't take chances, drop the connection.
825 if self.get_autocommit() != self.settings_dict["AUTOCOMMIT"]:
826 self.close()
827 return
828
829 # If an exception other than DataError or IntegrityError occurred
830 # since the last commit / rollback, check if the connection works.
831 if self.errors_occurred:
832 if self.is_usable():
833 self.errors_occurred = False
834 self.health_check_done = True
835 else:
836 self.close()
837 return
838
839 if self.close_at is not None and time.monotonic() >= self.close_at:
840 self.close()
841 return
842
843 # ##### Thread safety handling #####
844
845 @property
846 def allow_thread_sharing(self) -> bool:
847 with self._thread_sharing_lock:
848 return self._thread_sharing_count > 0
849
850 def validate_thread_sharing(self) -> None:
851 """
852 Validate that the connection isn't accessed by another thread than the
853 one which originally created it. Raise an exception if the validation
854 fails.
855 """
856 if not (self.allow_thread_sharing or self._thread_ident == _thread.get_ident()):
857 raise DatabaseError(
858 "DatabaseWrapper objects created in a "
859 "thread can only be used in that same thread. The connection "
860 f"was created in thread id {self._thread_ident} and this is "
861 f"thread id {_thread.get_ident()}."
862 )
863
864 # ##### Miscellaneous #####
865
866 @cached_property
867 def wrap_database_errors(self) -> DatabaseErrorWrapper:
868 """
869 Context manager and decorator that re-throws backend-specific database
870 exceptions using Plain's common wrappers.
871 """
872 return DatabaseErrorWrapper(self)
873
874 def make_cursor(self, cursor: Any) -> utils.CursorWrapper:
875 """Create a cursor without debug logging."""
876 return utils.CursorWrapper(cursor, self)
877
878 @contextmanager
879 def temporary_connection(self) -> Generator[utils.CursorWrapper, None, None]:
880 """
881 Context manager that ensures that a connection is established, and
882 if it opened one, closes it to avoid leaving a dangling connection.
883 This is useful for operations outside of the request-response cycle.
884
885 Provide a cursor: with self.temporary_connection() as cursor: ...
886 """
887 must_close = self.connection is None
888 try:
889 with self.cursor() as cursor:
890 yield cursor
891 finally:
892 if must_close:
893 self.close()
894
895 def schema_editor(self, *args: Any, **kwargs: Any) -> DatabaseSchemaEditor:
896 """Return a new instance of the schema editor."""
897 return DatabaseSchemaEditor(self, *args, **kwargs)
898
899 def runshell(self, parameters: list[str]) -> None:
900 """Run an interactive psql shell."""
901 args, env = _psql_settings_to_cmd_args_env(self.settings_dict, parameters)
902 env = {**os.environ, **env} if env else None
903 sigint_handler = signal.getsignal(signal.SIGINT)
904 try:
905 # Allow SIGINT to pass to psql to abort queries.
906 signal.signal(signal.SIGINT, signal.SIG_IGN)
907 subprocess.run(args, env=env, check=True)
908 finally:
909 # Restore the original SIGINT handler.
910 signal.signal(signal.SIGINT, sigint_handler)
911
912 def on_commit(self, func: Any, robust: bool = False) -> None:
913 if not callable(func):
914 raise TypeError("on_commit()'s callback must be a callable.")
915 if self.in_atomic_block:
916 # Transaction in progress; save for execution on commit.
917 self.run_on_commit.append((set(self.savepoint_ids), func, robust))
918 elif not self.get_autocommit():
919 raise TransactionManagementError(
920 "on_commit() cannot be used in manual transaction management"
921 )
922 else:
923 # No transaction in progress and in autocommit mode; execute
924 # immediately.
925 if robust:
926 try:
927 func()
928 except Exception as e:
929 logger.error(
930 f"Error calling {func.__qualname__} in on_commit() (%s).",
931 e,
932 exc_info=True,
933 )
934 else:
935 func()
936
937 def run_and_clear_commit_hooks(self) -> None:
938 self.validate_no_atomic_block()
939 current_run_on_commit = self.run_on_commit
940 self.run_on_commit = []
941 while current_run_on_commit:
942 _, func, robust = current_run_on_commit.pop(0)
943 if robust:
944 try:
945 func()
946 except Exception as e:
947 logger.error(
948 f"Error calling {func.__qualname__} in on_commit() during "
949 f"transaction (%s).",
950 e,
951 exc_info=True,
952 )
953 else:
954 func()
955
956 @contextmanager
957 def execute_wrapper(self, wrapper: Any) -> Generator[None, None, None]:
958 """
959 Return a context manager under which the wrapper is applied to suitable
960 database query executions.
961 """
962 self.execute_wrappers.append(wrapper)
963 try:
964 yield
965 finally:
966 self.execute_wrappers.pop()
967
968 # ##### SQL generation methods that require connection state #####
969
970 def compose_sql(self, query: str, params: Any) -> str:
971 """
972 Compose a SQL query with parameters using psycopg's mogrify.
973
974 This requires an active connection because it uses the connection's
975 cursor to properly format parameters.
976 """
977 assert self.connection is not None
978 return ClientCursor(self.connection).mogrify(
979 psycopg_sql.SQL(cast(LiteralString, query)), params
980 )
981
982 def last_executed_query(
983 self,
984 cursor: utils.CursorWrapper,
985 sql: str,
986 params: Any,
987 ) -> str | None:
988 """
989 Return a string of the query last executed by the given cursor, with
990 placeholders replaced with actual values.
991 """
992 try:
993 return self.compose_sql(sql, params)
994 except errors.DataError:
995 return None
996
997 def unification_cast_sql(self, output_field: Field) -> str:
998 """
999 Given a field instance, return the SQL that casts the result of a union
1000 to that type. The resulting string should contain a '%s' placeholder
1001 for the expression being cast.
1002 """
1003 internal_type = output_field.get_internal_type()
1004 if internal_type in (
1005 "GenericIPAddressField",
1006 "TimeField",
1007 "UUIDField",
1008 ):
1009 # PostgreSQL will resolve a union as type 'text' if input types are
1010 # 'unknown'.
1011 # https://www.postgresql.org/docs/current/typeconv-union-case.html
1012 # These fields cannot be implicitly cast back in the default
1013 # PostgreSQL configuration so we need to explicitly cast them.
1014 # We must also remove components of the type within brackets:
1015 # varchar(255) -> varchar.
1016 db_type = output_field.db_type()
1017 if db_type:
1018 return "CAST(%s AS {})".format(db_type.split("(")[0])
1019 return "%s"
1020
1021 # ##### Introspection methods #####
1022
1023 def table_names(
1024 self, cursor: CursorWrapper | None = None, include_views: bool = False
1025 ) -> list[str]:
1026 """
1027 Return a list of names of all tables that exist in the database.
1028 Sort the returned table list by Python's default sorting. Do NOT use
1029 the database's ORDER BY here to avoid subtle differences in sorting
1030 order between databases.
1031 """
1032
1033 def get_names(cursor: CursorWrapper) -> list[str]:
1034 return sorted(
1035 ti.name
1036 for ti in self.get_table_list(cursor)
1037 if include_views or ti.type == "t"
1038 )
1039
1040 if cursor is None:
1041 with self.cursor() as cursor:
1042 return get_names(cursor)
1043 return get_names(cursor)
1044
1045 def get_table_list(self, cursor: CursorWrapper) -> Sequence[TableInfo]:
1046 """
1047 Return an unsorted list of TableInfo named tuples of all tables and
1048 views that exist in the database.
1049 """
1050 cursor.execute(
1051 """
1052 SELECT
1053 c.relname,
1054 CASE
1055 WHEN c.relispartition THEN 'p'
1056 WHEN c.relkind IN ('m', 'v') THEN 'v'
1057 ELSE 't'
1058 END,
1059 obj_description(c.oid, 'pg_class')
1060 FROM pg_catalog.pg_class c
1061 LEFT JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace
1062 WHERE c.relkind IN ('f', 'm', 'p', 'r', 'v')
1063 AND n.nspname NOT IN ('pg_catalog', 'pg_toast')
1064 AND pg_catalog.pg_table_is_visible(c.oid)
1065 """
1066 )
1067 return [
1068 TableInfo(*row)
1069 for row in cursor.fetchall()
1070 if row[0] not in self.ignored_tables
1071 ]
1072
1073 def plain_table_names(
1074 self, only_existing: bool = False, include_views: bool = True
1075 ) -> list[str]:
1076 """
1077 Return a list of all table names that have associated Plain models and
1078 are in INSTALLED_PACKAGES.
1079
1080 If only_existing is True, include only the tables in the database.
1081 """
1082 tables = set()
1083 for model in get_migratable_models():
1084 tables.add(model.model_options.db_table)
1085 tables.update(
1086 f.m2m_db_table() for f in model._model_meta.local_many_to_many
1087 )
1088 tables = list(tables)
1089 if only_existing:
1090 existing_tables = set(self.table_names(include_views=include_views))
1091 tables = [t for t in tables if t in existing_tables]
1092 return tables
1093
1094 def get_sequences(
1095 self, cursor: CursorWrapper, table_name: str, table_fields: tuple[Any, ...] = ()
1096 ) -> list[dict[str, Any]]:
1097 """
1098 Return a list of introspected sequences for table_name. Each sequence
1099 is a dict: {'table': <table_name>, 'column': <column_name>, 'name': <sequence_name>}.
1100 """
1101 cursor.execute(
1102 """
1103 SELECT
1104 s.relname AS sequence_name,
1105 a.attname AS colname
1106 FROM
1107 pg_class s
1108 JOIN pg_depend d ON d.objid = s.oid
1109 AND d.classid = 'pg_class'::regclass
1110 AND d.refclassid = 'pg_class'::regclass
1111 JOIN pg_attribute a ON d.refobjid = a.attrelid
1112 AND d.refobjsubid = a.attnum
1113 JOIN pg_class tbl ON tbl.oid = d.refobjid
1114 AND tbl.relname = %s
1115 AND pg_catalog.pg_table_is_visible(tbl.oid)
1116 WHERE
1117 s.relkind = 'S';
1118 """,
1119 [table_name],
1120 )
1121 return [
1122 {"name": row[0], "table": table_name, "column": row[1]}
1123 for row in cursor.fetchall()
1124 ]
1125
1126 def get_constraints(
1127 self, cursor: CursorWrapper, table_name: str
1128 ) -> dict[str, dict[str, Any]]:
1129 """
1130 Retrieve any constraints or keys (unique, pk, fk, check, index) across
1131 one or more columns. Also retrieve the definition of expression-based
1132 indexes.
1133 """
1134 constraints: dict[str, dict[str, Any]] = {}
1135 # Loop over the key table, collecting things as constraints. The column
1136 # array must return column names in the same order in which they were
1137 # created.
1138 cursor.execute(
1139 """
1140 SELECT
1141 c.conname,
1142 array(
1143 SELECT attname
1144 FROM unnest(c.conkey) WITH ORDINALITY cols(colid, arridx)
1145 JOIN pg_attribute AS ca ON cols.colid = ca.attnum
1146 WHERE ca.attrelid = c.conrelid
1147 ORDER BY cols.arridx
1148 ),
1149 c.contype,
1150 (SELECT fkc.relname || '.' || fka.attname
1151 FROM pg_attribute AS fka
1152 JOIN pg_class AS fkc ON fka.attrelid = fkc.oid
1153 WHERE fka.attrelid = c.confrelid AND fka.attnum = c.confkey[1]),
1154 cl.reloptions
1155 FROM pg_constraint AS c
1156 JOIN pg_class AS cl ON c.conrelid = cl.oid
1157 WHERE cl.relname = %s AND pg_catalog.pg_table_is_visible(cl.oid)
1158 """,
1159 [table_name],
1160 )
1161 for constraint, columns, kind, used_cols, options in cursor.fetchall():
1162 constraints[constraint] = {
1163 "columns": columns,
1164 "primary_key": kind == "p",
1165 "unique": kind in ["p", "u"],
1166 "foreign_key": tuple(used_cols.split(".", 1)) if kind == "f" else None,
1167 "check": kind == "c",
1168 "index": False,
1169 "definition": None,
1170 "options": options,
1171 }
1172 # Now get indexes
1173 cursor.execute(
1174 """
1175 SELECT
1176 indexname,
1177 array_agg(attname ORDER BY arridx),
1178 indisunique,
1179 indisprimary,
1180 array_agg(ordering ORDER BY arridx),
1181 amname,
1182 exprdef,
1183 s2.attoptions
1184 FROM (
1185 SELECT
1186 c2.relname as indexname, idx.*, attr.attname, am.amname,
1187 CASE
1188 WHEN idx.indexprs IS NOT NULL THEN
1189 pg_get_indexdef(idx.indexrelid)
1190 END AS exprdef,
1191 CASE am.amname
1192 WHEN %s THEN
1193 CASE (option & 1)
1194 WHEN 1 THEN 'DESC' ELSE 'ASC'
1195 END
1196 END as ordering,
1197 c2.reloptions as attoptions
1198 FROM (
1199 SELECT *
1200 FROM
1201 pg_index i,
1202 unnest(i.indkey, i.indoption)
1203 WITH ORDINALITY koi(key, option, arridx)
1204 ) idx
1205 LEFT JOIN pg_class c ON idx.indrelid = c.oid
1206 LEFT JOIN pg_class c2 ON idx.indexrelid = c2.oid
1207 LEFT JOIN pg_am am ON c2.relam = am.oid
1208 LEFT JOIN
1209 pg_attribute attr ON attr.attrelid = c.oid AND attr.attnum = idx.key
1210 WHERE c.relname = %s AND pg_catalog.pg_table_is_visible(c.oid)
1211 ) s2
1212 GROUP BY indexname, indisunique, indisprimary, amname, exprdef, attoptions;
1213 """,
1214 [self.index_default_access_method, table_name],
1215 )
1216 for (
1217 index,
1218 columns,
1219 unique,
1220 primary,
1221 orders,
1222 type_,
1223 definition,
1224 options,
1225 ) in cursor.fetchall():
1226 if index not in constraints:
1227 basic_index = (
1228 type_ == self.index_default_access_method and options is None
1229 )
1230 constraints[index] = {
1231 "columns": columns if columns != [None] else [],
1232 "orders": orders if orders != [None] else [],
1233 "primary_key": primary,
1234 "unique": unique,
1235 "foreign_key": None,
1236 "check": False,
1237 "index": True,
1238 "type": Index.suffix if basic_index else type_,
1239 "definition": definition,
1240 "options": options,
1241 }
1242 return constraints
1243
1244 # ##### Test database creation methods (merged from DatabaseCreation) #####
1245
1246 def _log(self, msg: str) -> None:
1247 sys.stderr.write(msg + os.linesep)
1248
1249 def create_test_db(self, verbosity: int = 1, prefix: str = "") -> str:
1250 """
1251 Create a test database, prompting the user for confirmation if the
1252 database already exists. Return the name of the test database created.
1253
1254 If prefix is provided, it will be prepended to the database name
1255 to isolate it from other test databases.
1256 """
1257 from plain.models.cli.migrations import apply
1258
1259 test_database_name = self._get_test_db_name(prefix)
1260
1261 if verbosity >= 1:
1262 self._log(f"Creating test database '{test_database_name}'...")
1263
1264 self._create_test_db(
1265 test_database_name=test_database_name, verbosity=verbosity, autoclobber=True
1266 )
1267
1268 self.close()
1269 settings.DATABASE["NAME"] = test_database_name
1270 self.settings_dict["NAME"] = test_database_name
1271
1272 apply.callback(
1273 package_label=None,
1274 migration_name=None,
1275 fake=False,
1276 plan=False,
1277 check_unapplied=False,
1278 backup=False,
1279 no_input=True,
1280 atomic_batch=False, # No need for atomic batch when creating test database
1281 quiet=verbosity < 2, # Show migration output when verbosity is 2+
1282 )
1283
1284 # Ensure a connection for the side effect of initializing the test database.
1285 self.ensure_connection()
1286
1287 return test_database_name
1288
1289 def _get_test_db_name(self, prefix: str = "") -> str:
1290 """
1291 Internal implementation - return the name of the test DB that will be
1292 created. Only useful when called from create_test_db() and
1293 _create_test_db() and when no external munging is done with the 'NAME'
1294 settings.
1295
1296 If prefix is provided, it will be prepended to the database name.
1297 """
1298 # Determine the base name: explicit TEST.NAME overrides base NAME.
1299 base_name = self.settings_dict["TEST"]["NAME"] or self.settings_dict["NAME"]
1300 if prefix:
1301 return f"{prefix}_{base_name}"
1302 if self.settings_dict["TEST"]["NAME"]:
1303 return self.settings_dict["TEST"]["NAME"]
1304 name = self.settings_dict["NAME"]
1305 assert name is not None, "DATABASE NAME must be set"
1306 return TEST_DATABASE_PREFIX + name
1307
1308 def _get_database_create_suffix(
1309 self, encoding: str | None = None, template: str | None = None
1310 ) -> str:
1311 """Return PostgreSQL-specific CREATE DATABASE suffix."""
1312 suffix = ""
1313 if encoding:
1314 suffix += f" ENCODING '{encoding}'"
1315 if template:
1316 suffix += f" TEMPLATE {quote_name(template)}"
1317 return suffix and "WITH" + suffix
1318
1319 def _execute_create_test_db(self, cursor: Any, parameters: dict[str, str]) -> None:
1320 try:
1321 cursor.execute("CREATE DATABASE {dbname} {suffix}".format(**parameters))
1322 except Exception as e:
1323 cause = e.__cause__
1324 if cause and not isinstance(cause, errors.DuplicateDatabase):
1325 # All errors except "database already exists" cancel tests.
1326 self._log(f"Got an error creating the test database: {e}")
1327 sys.exit(2)
1328 else:
1329 raise
1330
1331 def _create_test_db(
1332 self, *, test_database_name: str, verbosity: int, autoclobber: bool
1333 ) -> str:
1334 """
1335 Internal implementation - create the test db tables.
1336 """
1337 test_db_params = {
1338 "dbname": quote_name(test_database_name),
1339 "suffix": self.sql_table_creation_suffix(),
1340 }
1341 # Create the test database and connect to it.
1342 with self._nodb_cursor() as cursor:
1343 try:
1344 self._execute_create_test_db(cursor, test_db_params)
1345 except Exception as e:
1346 self._log(f"Got an error creating the test database: {e}")
1347 if not autoclobber:
1348 confirm = input(
1349 "Type 'yes' if you would like to try deleting the test "
1350 f"database '{test_database_name}', or 'no' to cancel: "
1351 )
1352 if autoclobber or confirm == "yes":
1353 try:
1354 if verbosity >= 1:
1355 self._log(
1356 f"Destroying old test database '{test_database_name}'..."
1357 )
1358 cursor.execute(
1359 "DROP DATABASE {dbname}".format(**test_db_params)
1360 )
1361 self._execute_create_test_db(cursor, test_db_params)
1362 except Exception as e:
1363 self._log(f"Got an error recreating the test database: {e}")
1364 sys.exit(2)
1365 else:
1366 self._log("Tests cancelled.")
1367 sys.exit(1)
1368
1369 return test_database_name
1370
1371 def destroy_test_db(
1372 self, old_database_name: str | None = None, verbosity: int = 1
1373 ) -> None:
1374 """
1375 Destroy a test database, prompting the user for confirmation if the
1376 database already exists.
1377 """
1378 self.close()
1379
1380 test_database_name = self.settings_dict["NAME"]
1381 assert test_database_name is not None, "Test database NAME must be set"
1382
1383 if verbosity >= 1:
1384 self._log(f"Destroying test database '{test_database_name}'...")
1385 self._destroy_test_db(test_database_name, verbosity)
1386
1387 # Restore the original database name
1388 if old_database_name is not None:
1389 settings.DATABASE["NAME"] = old_database_name
1390 self.settings_dict["NAME"] = old_database_name
1391
1392 def _destroy_test_db(self, test_database_name: str, verbosity: int) -> None:
1393 """
1394 Internal implementation - remove the test db tables.
1395 """
1396 # Remove the test database to clean up after
1397 # ourselves. Connect to the previous database (not the test database)
1398 # to do so, because it's not allowed to delete a database while being
1399 # connected to it.
1400 with self._nodb_cursor() as cursor:
1401 cursor.execute(f"DROP DATABASE {quote_name(test_database_name)}")
1402
1403 def sql_table_creation_suffix(self) -> str:
1404 """
1405 SQL to append to the end of the test table creation statements.
1406 """
1407 test_settings = self.settings_dict["TEST"]
1408 return self._get_database_create_suffix(
1409 encoding=test_settings.get("CHARSET"),
1410 template=test_settings.get("TEMPLATE"),
1411 )
1412
1413
class CursorMixin:
    """
    Mixin adding a DB-API style ``callproc()`` to psycopg cursor classes.
    """

    def callproc(
        self, name: str | psycopg_sql.Identifier, args: list[Any] | None = None
    ) -> list[Any] | None:
        """
        Execute ``SELECT * FROM name(arg1,arg2,...)``.

        A plain-string ``name`` is wrapped in ``psycopg_sql.Identifier`` so it
        is quoted. Each argument is embedded as ``psycopg_sql.Literal``.
        Returns ``args`` unchanged.
        """
        func_ident = (
            name
            if isinstance(name, psycopg_sql.Identifier)
            else psycopg_sql.Identifier(name)
        )

        pieces: list[psycopg_sql.Composable] = [
            psycopg_sql.SQL("SELECT * FROM "),
            func_ident,
            psycopg_sql.SQL("("),
        ]
        for position, value in enumerate(args or []):
            # Comma goes before every argument except the first.
            if position:
                pieces.append(psycopg_sql.SQL(","))
            pieces.append(psycopg_sql.Literal(value))
        pieces.append(psycopg_sql.SQL(")"))

        self.execute(psycopg_sql.Composed(pieces))  # type: ignore[attr-defined]
        return args
1440
1441
class ServerBindingCursor(CursorMixin, Database.Cursor):
    """psycopg server-side-binding cursor with ``callproc()`` support."""
1444
1445
class Cursor(CursorMixin, Database.ClientCursor):
    """psycopg client-side-binding cursor with ``callproc()`` support."""
1448
1449
class CursorDebugWrapper(BaseCursorDebugWrapper):
    """Debug cursor wrapper that also routes COPY through ``debug_sql()``."""

    @contextmanager
    def copy(self, statement: Any) -> Generator[Any, None, None]:
        """
        Run a COPY ``statement`` on the wrapped cursor inside ``debug_sql()``.

        psycopg's ``Cursor.copy()`` is itself a context manager: the COPY
        statement is only sent when it is entered. The previous version
        returned it from inside ``debug_sql()``, which closed the debug
        context before any SQL ran. Entering it here keeps the debug context
        open for the whole operation. Call sites are unchanged:
        ``with cursor.copy(sql) as copy: ...``.
        """
        with self.debug_sql(statement):
            with self.cursor.copy(statement) as copy:
                yield copy