1"""
2PostgreSQL database backend for Plain.
3
4Requires psycopg2 >= 2.8.4 or psycopg >= 3.1.8
5"""
6
7import threading
8import warnings
9from contextlib import contextmanager
10
11from plain.exceptions import ImproperlyConfigured
12from plain.models.backends.base.base import BaseDatabaseWrapper
13from plain.models.backends.utils import CursorDebugWrapper as BaseCursorDebugWrapper
14from plain.models.db import DatabaseError as WrappedDatabaseError
15from plain.models.db import connections
16from plain.runtime import settings
17from plain.utils.functional import cached_property
18from plain.utils.safestring import SafeString
19
try:
    try:
        # Prefer psycopg 3 when it is installed.
        import psycopg as Database
    except ImportError:
        # Fall back to psycopg2.
        import psycopg2 as Database
except ImportError as exc:
    # Neither driver could be imported; keep the ImportError as the explicit
    # cause so the underlying reason shows up in the traceback.
    raise ImproperlyConfigured("Error loading psycopg2 or psycopg module") from exc
27
28
29from .psycopg_any import IsolationLevel, is_psycopg3 # NOQA isort:skip
30
if not is_psycopg3:
    import psycopg2.extensions
    import psycopg2.extras

    # Adapt SafeString values with psycopg2's QuotedString adapter.
    psycopg2.extensions.register_adapter(SafeString, psycopg2.extensions.QuotedString)
    # Register psycopg2's UUID support.
    psycopg2.extras.register_uuid()

    # inet[] (OID 1041) is registered by hand with the UNICODE caster so its
    # elements are loaded as strings instead of Inet() objects that would
    # otherwise have to be unwrapped on every load.
    INETARRAY_OID = 1041
    INETARRAY = psycopg2.extensions.new_array_type(
        (INETARRAY_OID,), "INETARRAY", psycopg2.extensions.UNICODE
    )
    psycopg2.extensions.register_type(INETARRAY)
else:
    from psycopg import adapters, sql
    from psycopg.pq import Format

    from .psycopg_any import get_adapters_template, register_tzloader

    # OID of timestamptz, used to look up the connection's loader for it.
    TIMESTAMPTZ_OID = adapters.types["timestamptz"].oid
55
56# Some of these import psycopg, so import them after checking if it's installed.
57from .client import DatabaseClient # NOQA isort:skip
58from .creation import DatabaseCreation # NOQA isort:skip
59from .features import DatabaseFeatures # NOQA isort:skip
60from .introspection import DatabaseIntrospection # NOQA isort:skip
61from .operations import DatabaseOperations # NOQA isort:skip
62from .schema import DatabaseSchemaEditor # NOQA isort:skip
63
64
65def _get_varchar_column(data):
66 if data["max_length"] is None:
67 return "varchar"
68 return "varchar({max_length})".format(**data)
69
70
class DatabaseWrapper(BaseDatabaseWrapper):
    """
    PostgreSQL database wrapper for Plain.

    Works with either psycopg 3 or psycopg2; the driver is chosen at import
    time (``is_psycopg3``) and the methods below branch on it wherever the
    two drivers' APIs differ.
    """

    vendor = "postgresql"
    display_name = "PostgreSQL"
    # This dictionary maps Field objects to their associated PostgreSQL column
    # types, as strings. Column-type strings can contain format strings; they'll
    # be interpolated against the values of Field.__dict__ before being output.
    # If a column type is set to None, it won't be included in the output.
    # CharField maps to a callable (_get_varchar_column) rather than a string;
    # presumably it is called with the same dict — confirm in the base class.
    data_types = {
        "AutoField": "integer",
        "BigAutoField": "bigint",
        "BinaryField": "bytea",
        "BooleanField": "boolean",
        "CharField": _get_varchar_column,
        "DateField": "date",
        "DateTimeField": "timestamp with time zone",
        "DecimalField": "numeric(%(max_digits)s, %(decimal_places)s)",
        "DurationField": "interval",
        "FloatField": "double precision",
        "IntegerField": "integer",
        "BigIntegerField": "bigint",
        "IPAddressField": "inet",
        "GenericIPAddressField": "inet",
        "JSONField": "jsonb",
        "OneToOneField": "integer",
        "PositiveBigIntegerField": "bigint",
        "PositiveIntegerField": "integer",
        "PositiveSmallIntegerField": "smallint",
        "SlugField": "varchar(%(max_length)s)",
        "SmallAutoField": "smallint",
        "SmallIntegerField": "smallint",
        "TextField": "text",
        "TimeField": "time",
        "UUIDField": "uuid",
    }
    # CHECK constraints that keep the Positive* fields non-negative.
    data_type_check_constraints = {
        "PositiveBigIntegerField": '"%(column)s" >= 0',
        "PositiveIntegerField": '"%(column)s" >= 0',
        "PositiveSmallIntegerField": '"%(column)s" >= 0',
    }
    # Auto fields become identity columns. BY DEFAULT (rather than ALWAYS)
    # still allows explicit values to be inserted.
    data_types_suffix = {
        "AutoField": "GENERATED BY DEFAULT AS IDENTITY",
        "BigAutoField": "GENERATED BY DEFAULT AS IDENTITY",
        "SmallAutoField": "GENERATED BY DEFAULT AS IDENTITY",
    }
    # SQL templates for lookups whose right-hand side is a plain parameter.
    operators = {
        "exact": "= %s",
        "iexact": "= UPPER(%s)",
        "contains": "LIKE %s",
        "icontains": "LIKE UPPER(%s)",
        "regex": "~ %s",
        "iregex": "~* %s",
        "gt": "> %s",
        "gte": ">= %s",
        "lt": "< %s",
        "lte": "<= %s",
        "startswith": "LIKE %s",
        "endswith": "LIKE %s",
        "istartswith": "LIKE UPPER(%s)",
        "iendswith": "LIKE UPPER(%s)",
    }

    # The patterns below are used to generate SQL pattern lookup clauses when
    # the right-hand side of the lookup isn't a raw string (it might be an expression
    # or the result of a bilateral transformation).
    # In those cases, special characters for LIKE operators (\, % and _) must be
    # escaped on the database side, which is what pattern_esc does.
    #
    # Note: we use str.format() here for readability as '%' is used as a wildcard for
    # the LIKE operator.
    pattern_esc = (
        r"REPLACE(REPLACE(REPLACE({}, E'\\', E'\\\\'), E'%%', E'\\%%'), E'_', E'\\_')"
    )
    pattern_ops = {
        "contains": "LIKE '%%' || {} || '%%'",
        "icontains": "LIKE '%%' || UPPER({}) || '%%'",
        "startswith": "LIKE {} || '%%'",
        "istartswith": "LIKE UPPER({}) || '%%'",
        "endswith": "LIKE '%%' || {}",
        "iendswith": "LIKE '%%' || UPPER({})",
    }

    Database = Database
    SchemaEditorClass = DatabaseSchemaEditor
    # Classes instantiated in __init__().
    client_class = DatabaseClient
    creation_class = DatabaseCreation
    features_class = DatabaseFeatures
    introspection_class = DatabaseIntrospection
    ops_class = DatabaseOperations
    # PostgreSQL backend-specific attributes.
    # Counter used by chunked_cursor() to build unique named-cursor names.
    _named_cursor_idx = 0

    def get_database_version(self):
        """
        Return a tuple of the database's version.
        E.g. for pg_version 120004, return (12, 4).
        """
        return divmod(self.pg_version, 10000)

    def get_connection_params(self):
        """
        Build the keyword arguments passed to ``Database.connect()``.

        NAME == "" is only accepted together with OPTIONS["service"]; NAME set
        to None connects to the default 'postgres' database. The Plain-only
        OPTIONS keys (assume_role, isolation_level, server_side_binding) are
        removed because they are applied after connecting.

        Raises ImproperlyConfigured if neither NAME nor OPTIONS["service"] is
        supplied, or if NAME exceeds PostgreSQL's identifier length limit.
        """
        settings_dict = self.settings_dict
        options = settings_dict.get("OPTIONS", {})
        # None may be used to connect to the default 'postgres' db
        if settings_dict["NAME"] == "" and not options.get("service"):
            raise ImproperlyConfigured(
                "settings.DATABASES is improperly configured. "
                "Please supply the NAME or OPTIONS['service'] value."
            )
        if len(settings_dict["NAME"] or "") > self.ops.max_name_length():
            raise ImproperlyConfigured(
                "The database name '%s' (%d characters) is longer than "
                "PostgreSQL's limit of %d characters. Supply a shorter NAME "
                "in settings.DATABASES."
                % (
                    settings_dict["NAME"],
                    len(settings_dict["NAME"]),
                    self.ops.max_name_length(),
                )
            )
        if settings_dict["NAME"]:
            conn_params = {"dbname": settings_dict["NAME"], **options}
        elif settings_dict["NAME"] is None:
            # Connect to the default 'postgres' db. "service" is dropped from
            # the local copy only: OPTIONS can be shared with other wrappers
            # built from a shallow copy of this settings dict (see
            # _nodb_cursor), so it must not be mutated here.
            conn_params = {"dbname": "postgres", **options}
            conn_params.pop("service", None)
        else:
            conn_params = {**options}
        # Default the client encoding to UTF-8 unless OPTIONS chose one.
        conn_params.setdefault("client_encoding", "UTF8")

        conn_params.pop("assume_role", None)
        conn_params.pop("isolation_level", None)
        conn_params.pop("server_side_binding", None)
        if settings_dict["USER"]:
            conn_params["user"] = settings_dict["USER"]
        if settings_dict["PASSWORD"]:
            conn_params["password"] = settings_dict["PASSWORD"]
        if settings_dict["HOST"]:
            conn_params["host"] = settings_dict["HOST"]
        if settings_dict["PORT"]:
            conn_params["port"] = settings_dict["PORT"]
        if is_psycopg3:
            conn_params["context"] = get_adapters_template(
                settings.USE_TZ, self.timezone
            )
            # Disable prepared statements by default to keep connection poolers
            # working. Can be reenabled via OPTIONS in the settings dict.
            conn_params["prepare_threshold"] = conn_params.pop(
                "prepare_threshold", None
            )
        return conn_params

    def get_new_connection(self, conn_params):
        """
        Open a driver connection from *conn_params* and configure it.

        Applies OPTIONS["isolation_level"] when given and installs the cursor
        factory. On psycopg2 it also registers a pass-through jsonb loader.

        Raises ImproperlyConfigured if OPTIONS["isolation_level"] is not a
        valid IsolationLevel value.
        """
        # Resolve self.isolation_level from OPTIONS (falling back to
        # READ COMMITTED) and apply it to the connection only when it was
        # configured explicitly. It must be set before _set_autocommit() is
        # called because, if autocommit is on, that will set
        # connection.isolation_level to ISOLATION_LEVEL_AUTOCOMMIT.
        options = self.settings_dict["OPTIONS"]
        set_isolation_level = False
        try:
            isolation_level_value = options["isolation_level"]
        except KeyError:
            self.isolation_level = IsolationLevel.READ_COMMITTED
        else:
            # Set the isolation level to the value from OPTIONS.
            try:
                self.isolation_level = IsolationLevel(isolation_level_value)
                set_isolation_level = True
            except ValueError:
                raise ImproperlyConfigured(
                    f"Invalid transaction isolation level {isolation_level_value} "
                    "specified. Use one of the psycopg.IsolationLevel values."
                )
        connection = self.Database.connect(**conn_params)
        if set_isolation_level:
            connection.isolation_level = self.isolation_level
        if is_psycopg3:
            connection.cursor_factory = (
                ServerBindingCursor
                if options.get("server_side_binding") is True
                else Cursor
            )
        else:
            # Register dummy loads() to avoid a round trip from psycopg2's
            # decode to json.dumps() to json.loads(), when using a custom
            # decoder in JSONField.
            psycopg2.extras.register_default_jsonb(
                conn_or_curs=connection, loads=lambda x: x
            )
            connection.cursor_factory = Cursor
        return connection

    def ensure_timezone(self):
        """
        Issue SET TIME ZONE when the session time zone differs from
        self.timezone_name.

        Return True if a statement was executed (the caller may need to
        commit), False otherwise, including when there is no connection.
        """
        if self.connection is None:
            return False
        conn_timezone_name = self.connection.info.parameter_status("TimeZone")
        timezone_name = self.timezone_name
        if timezone_name and conn_timezone_name != timezone_name:
            with self.connection.cursor() as cursor:
                cursor.execute(self.ops.set_time_zone_sql(), [timezone_name])
            return True
        return False

    def ensure_role(self):
        """
        Run SET ROLE for OPTIONS["assume_role"] when it is configured.

        Return True if a statement was executed, False otherwise, including
        when there is no connection.
        """
        if self.connection is None:
            return False
        if new_role := self.settings_dict.get("OPTIONS", {}).get("assume_role"):
            with self.connection.cursor() as cursor:
                # Named to avoid shadowing the module-level psycopg ``sql``.
                set_role_sql = self.ops.compose_sql("SET ROLE %s", [new_role])
                cursor.execute(set_role_sql)
            return True
        return False

    def init_connection_state(self):
        """Apply time zone and role settings to a freshly opened connection."""
        super().init_connection_state()

        # Commit after setting the time zone.
        commit_tz = self.ensure_timezone()
        # Set the role on the connection. This is useful if the credential used
        # to login is not the same as the role that owns database resources. As
        # can be the case when using temporary or ephemeral credentials.
        commit_role = self.ensure_role()

        if (commit_role or commit_tz) and not self.get_autocommit():
            self.connection.commit()

    def create_cursor(self, name=None):
        """
        Create a driver cursor; a named cursor when *name* is given.

        Also makes sure timestamptz values are loaded in self.timezone
        (psycopg 3) or with self.tzinfo_factory when USE_TZ is on (psycopg2).
        """
        if name:
            # In autocommit mode, the cursor will be used outside of a
            # transaction, hence use a holdable cursor.
            cursor = self.connection.cursor(
                name, scrollable=False, withhold=self.connection.autocommit
            )
        else:
            cursor = self.connection.cursor()

        if is_psycopg3:
            # Register the cursor timezone only if the connection disagrees, to
            # avoid copying the adapter map.
            tzloader = self.connection.adapters.get_loader(TIMESTAMPTZ_OID, Format.TEXT)
            if self.timezone != tzloader.timezone:
                register_tzloader(self.timezone, cursor)
        else:
            cursor.tzinfo_factory = self.tzinfo_factory if settings.USE_TZ else None
        return cursor

    def tzinfo_factory(self, offset):
        """psycopg2 tzinfo factory: always return self.timezone, ignoring *offset*."""
        return self.timezone

    def chunked_cursor(self):
        """Return a named cursor whose name is unique per thread and per call."""
        self._named_cursor_idx += 1
        # Placeholder for a per-task identifier. Only synchronous execution is
        # handled here, so it is a constant; it is kept in the name format so
        # the naming scheme does not need to change if that ever differs.
        task_ident = "sync"
        # Use that and the thread ident to get a unique name
        return self._cursor(
            name="_plain_curs_%d_%s_%d"
            % (
                # Avoid reusing name in other threads / tasks
                threading.current_thread().ident,
                task_ident,
                self._named_cursor_idx,
            )
        )

    def _set_autocommit(self, autocommit):
        """Set the driver connection's autocommit flag, wrapping driver errors."""
        with self.wrap_database_errors:
            self.connection.autocommit = autocommit

    def check_constraints(self, table_names=None):
        """
        Check constraints by setting them to immediate. Return them to deferred
        afterward.

        *table_names* is accepted for interface compatibility but ignored:
        the statements apply to ALL constraints.
        """
        with self.cursor() as cursor:
            cursor.execute("SET CONSTRAINTS ALL IMMEDIATE")
            cursor.execute("SET CONSTRAINTS ALL DEFERRED")

    def is_usable(self):
        """Return True if a trivial query succeeds on the current connection."""
        try:
            # Use a psycopg cursor directly, bypassing Plain's utilities.
            with self.connection.cursor() as cursor:
                cursor.execute("SELECT 1")
        except Database.Error:
            return False
        else:
            return True

    @contextmanager
    def _nodb_cursor(self):
        """
        Yield a cursor from the base class's no-database connection.

        If that connection fails before a cursor is yielded, warn with
        RuntimeWarning and fall back to the first configured PostgreSQL
        database whose NAME is not 'postgres'. Re-raise if there is none, or
        if the error happened after the cursor was handed out.
        """
        cursor = None
        try:
            with super()._nodb_cursor() as cursor:
                yield cursor
        except (Database.DatabaseError, WrappedDatabaseError):
            if cursor is not None:
                raise
            warnings.warn(
                "Normally Plain will use a connection to the 'postgres' database "
                "to avoid running initialization queries against the production "
                "database when it's not needed (for example, when running tests). "
                "Plain was unable to create a connection to the 'postgres' database "
                "and will use the first PostgreSQL database instead.",
                RuntimeWarning,
            )
            for connection in connections.all():
                if (
                    connection.vendor == "postgresql"
                    and connection.settings_dict["NAME"] != "postgres"
                ):
                    conn = self.__class__(
                        {
                            **self.settings_dict,
                            "NAME": connection.settings_dict["NAME"],
                        },
                        alias=self.alias,
                    )
                    try:
                        with conn.cursor() as cursor:
                            yield cursor
                    finally:
                        conn.close()
                    break
            else:
                raise

    @cached_property
    def pg_version(self):
        """Server version as an integer (e.g. 120004), fetched once and cached."""
        with self.temporary_connection():
            return self.connection.info.server_version

    def make_debug_cursor(self, cursor):
        """Wrap *cursor* in the backend-specific CursorDebugWrapper."""
        return CursorDebugWrapper(cursor, self)
411
412
if not is_psycopg3:
    # psycopg2 needs no extra cursor behavior; use its default cursor class.
    Cursor = psycopg2.extensions.cursor

    class CursorDebugWrapper(BaseCursorDebugWrapper):
        """Debug wrapper that also routes psycopg2's COPY helpers through debug_sql."""

        def copy_expert(self, sql, file, *args):
            """Run ``copy_expert()`` under ``debug_sql`` for *sql*."""
            with self.debug_sql(sql):
                return self.cursor.copy_expert(sql, file, *args)

        def copy_to(self, file, table, *args, **kwargs):
            """Run ``copy_to()`` under ``debug_sql`` as ``COPY <table> TO STDOUT``."""
            statement = "COPY %s TO STDOUT" % table
            with self.debug_sql(sql=statement):
                return self.cursor.copy_to(file, table, *args, **kwargs)

else:

    class CursorMixin:
        """
        Mixin that adds a DB-API style ``callproc()`` to psycopg 3 cursors.
        """

        def callproc(self, name, args=None):
            """
            Execute ``SELECT * FROM name(arg, ...)`` and return *args* as given.

            *name* may be a plain string (quoted as an identifier) or an
            existing ``sql.Identifier``. Each argument is embedded as a
            ``sql.Literal``.
            """
            proc = name if isinstance(name, sql.Identifier) else sql.Identifier(name)
            parts = [sql.SQL("SELECT * FROM "), proc, sql.SQL("(")]
            if args:
                parts.append(sql.SQL(",").join(sql.Literal(item) for item in args))
            parts.append(sql.SQL(")"))
            self.execute(sql.Composed(parts))
            return args

    class ServerBindingCursor(CursorMixin, Database.Cursor):
        pass

    class Cursor(CursorMixin, Database.ClientCursor):
        pass

    class CursorDebugWrapper(BaseCursorDebugWrapper):
        """Debug wrapper that also routes psycopg 3's ``copy()`` through debug_sql."""

        def copy(self, statement):
            """Run ``copy()`` under ``debug_sql`` for *statement*."""
            with self.debug_sql(statement):
                return self.cursor.copy(statement)