  1"""
  2PostgreSQL database backend for Plain.
  3
  4Requires psycopg2 >= 2.8.4 or psycopg >= 3.1.8
  5"""
  6
  7import threading
  8import warnings
  9from contextlib import contextmanager
 10
 11from plain.exceptions import ImproperlyConfigured
 12from plain.models.backends.base.base import BaseDatabaseWrapper
 13from plain.models.backends.utils import CursorDebugWrapper as BaseCursorDebugWrapper
 14from plain.models.db import DatabaseError as WrappedDatabaseError
 15from plain.models.db import connections
 16from plain.runtime import settings
 17from plain.utils.functional import cached_property
 18from plain.utils.safestring import SafeString
 19
 20try:
 21    try:
 22        import psycopg as Database
 23    except ImportError:
 24        import psycopg2 as Database
 25except ImportError:
 26    raise ImproperlyConfigured("Error loading psycopg2 or psycopg module")
 27
 28
 29from .psycopg_any import IsolationLevel, is_psycopg3  # NOQA isort:skip
 30
 31if is_psycopg3:
 32    from psycopg import adapters, sql
 33    from psycopg.pq import Format
 34
 35    from .psycopg_any import get_adapters_template, register_tzloader
 36
 37    TIMESTAMPTZ_OID = adapters.types["timestamptz"].oid
 38
 39else:
 40    import psycopg2.extensions
 41    import psycopg2.extras
 42
 43    psycopg2.extensions.register_adapter(SafeString, psycopg2.extensions.QuotedString)
 44    psycopg2.extras.register_uuid()
 45
 46    # Register support for inet[] manually so we don't have to handle the Inet()
 47    # object on load all the time.
 48    INETARRAY_OID = 1041
 49    INETARRAY = psycopg2.extensions.new_array_type(
 50        (INETARRAY_OID,),
 51        "INETARRAY",
 52        psycopg2.extensions.UNICODE,
 53    )
 54    psycopg2.extensions.register_type(INETARRAY)
 55
 56# Some of these import psycopg, so import them after checking if it's installed.
 57from .client import DatabaseClient  # NOQA isort:skip
 58from .creation import DatabaseCreation  # NOQA isort:skip
 59from .features import DatabaseFeatures  # NOQA isort:skip
 60from .introspection import DatabaseIntrospection  # NOQA isort:skip
 61from .operations import DatabaseOperations  # NOQA isort:skip
 62from .schema import DatabaseSchemaEditor  # NOQA isort:skip
 63
 64
 65def _get_varchar_column(data):
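    # Maps a CharField's column data to a type string; for example,
    # _get_varchar_column({"max_length": 100}) returns "varchar(100)", while
    # a max_length of None falls back to an unbounded "varchar".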
    if data["max_length"] is None:
        return "varchar"
    return "varchar({max_length})".format(**data)


class DatabaseWrapper(BaseDatabaseWrapper):
    vendor = "postgresql"
    display_name = "PostgreSQL"
    # This dictionary maps Field objects to their associated PostgreSQL column
    # types, as strings. Column-type strings can contain format strings; they'll
    # be interpolated against the values of Field.__dict__ before being output.
    # If a column type is set to None, it won't be included in the output.
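    # For example, a DecimalField with max_digits=10 and decimal_places=2 is
    # rendered as "numeric(10, 2)".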
    data_types = {
        "AutoField": "integer",
        "BigAutoField": "bigint",
        "BinaryField": "bytea",
        "BooleanField": "boolean",
        "CharField": _get_varchar_column,
        "DateField": "date",
        "DateTimeField": "timestamp with time zone",
        "DecimalField": "numeric(%(max_digits)s, %(decimal_places)s)",
        "DurationField": "interval",
        "FloatField": "double precision",
        "IntegerField": "integer",
        "BigIntegerField": "bigint",
        "IPAddressField": "inet",
        "GenericIPAddressField": "inet",
        "JSONField": "jsonb",
        "OneToOneField": "integer",
        "PositiveBigIntegerField": "bigint",
        "PositiveIntegerField": "integer",
        "PositiveSmallIntegerField": "smallint",
        "SlugField": "varchar(%(max_length)s)",
        "SmallAutoField": "smallint",
        "SmallIntegerField": "smallint",
        "TextField": "text",
        "TimeField": "time",
        "UUIDField": "uuid",
    }
    data_type_check_constraints = {
        "PositiveBigIntegerField": '"%(column)s" >= 0',
        "PositiveIntegerField": '"%(column)s" >= 0',
        "PositiveSmallIntegerField": '"%(column)s" >= 0',
    }
    data_types_suffix = {
        "AutoField": "GENERATED BY DEFAULT AS IDENTITY",
        "BigAutoField": "GENERATED BY DEFAULT AS IDENTITY",
        "SmallAutoField": "GENERATED BY DEFAULT AS IDENTITY",
    }
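    # The suffix is appended after the column type, so an AutoField primary
    # key column is created as e.g.: "id" integer GENERATED BY DEFAULT AS IDENTITY.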
    operators = {
        "exact": "= %s",
        "iexact": "= UPPER(%s)",
        "contains": "LIKE %s",
        "icontains": "LIKE UPPER(%s)",
        "regex": "~ %s",
        "iregex": "~* %s",
        "gt": "> %s",
        "gte": ">= %s",
        "lt": "< %s",
        "lte": "<= %s",
        "startswith": "LIKE %s",
        "endswith": "LIKE %s",
        "istartswith": "LIKE UPPER(%s)",
        "iendswith": "LIKE UPPER(%s)",
    }

    # The patterns below are used to generate SQL pattern lookup clauses when
    # the right-hand side of the lookup isn't a raw string (it might be an
    # expression or the result of a bilateral transformation). In those cases,
    # special characters for LIKE operators (e.g. \, %, _) should be escaped on
    # the database side.
    #
    # Note: we use str.format() here for readability as '%' is used as a wildcard
    # for the LIKE operator.
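    # For example, formatting pattern_esc with a column reference "name"
    # yields (once the doubled '%%' collapses during parameter handling):
    # REPLACE(REPLACE(REPLACE("name", E'\\', E'\\\\'), E'%', E'\%'), E'_', E'\_')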
    pattern_esc = (
        r"REPLACE(REPLACE(REPLACE({}, E'\\', E'\\\\'), E'%%', E'\\%%'), E'_', E'\\_')"
    )
    pattern_ops = {
        "contains": "LIKE '%%' || {} || '%%'",
        "icontains": "LIKE '%%' || UPPER({}) || '%%'",
        "startswith": "LIKE {} || '%%'",
        "istartswith": "LIKE UPPER({}) || '%%'",
        "endswith": "LIKE '%%' || {}",
        "iendswith": "LIKE '%%' || UPPER({})",
    }

    Database = Database
    SchemaEditorClass = DatabaseSchemaEditor
    # Classes instantiated in __init__().
    client_class = DatabaseClient
    creation_class = DatabaseCreation
    features_class = DatabaseFeatures
    introspection_class = DatabaseIntrospection
    ops_class = DatabaseOperations
    # PostgreSQL backend-specific attributes.
    _named_cursor_idx = 0

    def get_database_version(self):
        """
        Return a tuple of the database's version.
        E.g. for pg_version 120004, return (12, 4).
        """
        return divmod(self.pg_version, 10000)

    def get_connection_params(self):
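        """
        Build the keyword arguments passed to psycopg's connect() from
        settings.DATABASES. As a rough sketch (hypothetical values):
        NAME="mydb", USER="app", HOST="localhost", PORT="5432" yields
        {"dbname": "mydb", "user": "app", "host": "localhost", "port": "5432"},
        merged with anything set in OPTIONS.
        """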
        settings_dict = self.settings_dict
        # None may be used to connect to the default 'postgres' db
        if settings_dict["NAME"] == "" and not settings_dict.get("OPTIONS", {}).get(
            "service"
        ):
            raise ImproperlyConfigured(
                "settings.DATABASES is improperly configured. "
                "Please supply the NAME or OPTIONS['service'] value."
            )
        if len(settings_dict["NAME"] or "") > self.ops.max_name_length():
            raise ImproperlyConfigured(
                "The database name '%s' (%d characters) is longer than "
                "PostgreSQL's limit of %d characters. Supply a shorter NAME "
                "in settings.DATABASES."
                % (
                    settings_dict["NAME"],
                    len(settings_dict["NAME"]),
                    self.ops.max_name_length(),
                )
            )
        conn_params = {"client_encoding": "UTF8"}
        if settings_dict["NAME"]:
            conn_params = {
                "dbname": settings_dict["NAME"],
                **settings_dict["OPTIONS"],
            }
        elif settings_dict["NAME"] is None:
            # Connect to the default 'postgres' db.
            settings_dict.get("OPTIONS", {}).pop("service", None)
            conn_params = {"dbname": "postgres", **settings_dict["OPTIONS"]}
        else:
            conn_params = {**settings_dict["OPTIONS"]}
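
        # These backend-specific options are read by Plain itself and aren't
        # valid psycopg connect() arguments, so strip them before connecting.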
        conn_params.pop("assume_role", None)
        conn_params.pop("isolation_level", None)
        conn_params.pop("server_side_binding", None)
        if settings_dict["USER"]:
            conn_params["user"] = settings_dict["USER"]
        if settings_dict["PASSWORD"]:
            conn_params["password"] = settings_dict["PASSWORD"]
        if settings_dict["HOST"]:
            conn_params["host"] = settings_dict["HOST"]
        if settings_dict["PORT"]:
            conn_params["port"] = settings_dict["PORT"]
        if is_psycopg3:
            conn_params["context"] = get_adapters_template(
                settings.USE_TZ, self.timezone
            )
            # Disable prepared statements by default to keep connection poolers
            # working. Can be re-enabled via OPTIONS in the settings dict.
            conn_params["prepare_threshold"] = conn_params.pop(
                "prepare_threshold", None
            )
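            # For example, OPTIONS = {"prepare_threshold": 5} (psycopg's own
            # default) would re-enable automatic prepared statements once a
            # query has been executed five times.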
        return conn_params

    def get_new_connection(self, conn_params):
        # self.isolation_level must be set:
        # - after connecting to the database in order to obtain the database's
        #   default when no value is explicitly specified in options.
        # - before calling _set_autocommit() because if autocommit is on, that
        #   will set connection.isolation_level to ISOLATION_LEVEL_AUTOCOMMIT.
        options = self.settings_dict["OPTIONS"]
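        # For example, OPTIONS = {"isolation_level": IsolationLevel.SERIALIZABLE}
        # (an illustrative value) would run every transaction on this
        # connection at SERIALIZABLE rather than the server's default.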
        set_isolation_level = False
        try:
            isolation_level_value = options["isolation_level"]
        except KeyError:
            self.isolation_level = IsolationLevel.READ_COMMITTED
        else:
            # Set the isolation level to the value from OPTIONS.
            try:
                self.isolation_level = IsolationLevel(isolation_level_value)
                set_isolation_level = True
            except ValueError:
                raise ImproperlyConfigured(
                    f"Invalid transaction isolation level {isolation_level_value} "
                    f"specified. Use one of the psycopg.IsolationLevel values."
                )
        connection = self.Database.connect(**conn_params)
        if set_isolation_level:
            connection.isolation_level = self.isolation_level
        if is_psycopg3:
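            # psycopg3 binds query parameters server-side by default; Plain
            # keeps psycopg2-style client-side interpolation (ClientCursor)
            # unless OPTIONS["server_side_binding"] is explicitly True.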
            connection.cursor_factory = (
                ServerBindingCursor
                if options.get("server_side_binding") is True
                else Cursor
            )
        else:
            # Register dummy loads() to avoid a round trip from psycopg2's
            # decode to json.dumps() to json.loads(), when using a custom
            # decoder in JSONField.
            psycopg2.extras.register_default_jsonb(
                conn_or_curs=connection, loads=lambda x: x
            )
            connection.cursor_factory = Cursor
        return connection

    def ensure_timezone(self):
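        """
        Make sure the connection's TimeZone setting matches self.timezone_name,
        issuing the backend's set-time-zone SQL when they disagree. Return True
        if a change was made, so callers know a commit may be needed.
        """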
        if self.connection is None:
            return False
        conn_timezone_name = self.connection.info.parameter_status("TimeZone")
        timezone_name = self.timezone_name
        if timezone_name and conn_timezone_name != timezone_name:
            with self.connection.cursor() as cursor:
                cursor.execute(self.ops.set_time_zone_sql(), [timezone_name])
            return True
        return False

    def ensure_role(self):
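        """
        Assume a configured role on the connection, if any. For example
        (hypothetical value), OPTIONS = {"assume_role": "app_owner"} issues
        SET ROLE 'app_owner' so queries run with that role's permissions.
        """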
        if self.connection is None:
            return False
        if new_role := self.settings_dict.get("OPTIONS", {}).get("assume_role"):
            with self.connection.cursor() as cursor:
                sql = self.ops.compose_sql("SET ROLE %s", [new_role])
                cursor.execute(sql)
            return True
        return False

    def init_connection_state(self):
        super().init_connection_state()

        # Commit after setting the time zone.
        commit_tz = self.ensure_timezone()
        # Set the role on the connection. This is useful if the credential
        # used to log in is not the same as the role that owns the database
        # resources, as can be the case with temporary or ephemeral
        # credentials.
        commit_role = self.ensure_role()

        if (commit_role or commit_tz) and not self.get_autocommit():
            self.connection.commit()

    def create_cursor(self, name=None):
        if name:
            # In autocommit mode, the cursor will be used outside of a
            # transaction, hence use a holdable cursor.
            cursor = self.connection.cursor(
                name, scrollable=False, withhold=self.connection.autocommit
            )
        else:
            cursor = self.connection.cursor()

        if is_psycopg3:
            # Register the cursor timezone only if the connection disagrees, to
            # avoid copying the adapter map.
            tzloader = self.connection.adapters.get_loader(TIMESTAMPTZ_OID, Format.TEXT)
            if self.timezone != tzloader.timezone:
                register_tzloader(self.timezone, cursor)
        else:
            cursor.tzinfo_factory = self.tzinfo_factory if settings.USE_TZ else None
        return cursor

    def tzinfo_factory(self, offset):
        return self.timezone

    def chunked_cursor(self):
        self._named_cursor_idx += 1
        # Get the current async task.
        # Note that right now this is behind @async_unsafe, so this is
        # unreachable, but in the future we'll start loosening this restriction.
        # For now, it's here so that every use of "threading" is also
        # async-compatible.
        task_ident = "sync"
        # Use that and the thread ident to get a unique name.
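        # The generated name looks like e.g. "_plain_curs_140233748415232_sync_1"
        # (thread ident, task ident, per-connection counter), so concurrent
        # server-side cursors never collide.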
        return self._cursor(
            name="_plain_curs_%d_%s_%d"
            % (
                # Avoid reusing name in other threads / tasks
                threading.current_thread().ident,
                task_ident,
                self._named_cursor_idx,
            )
        )

    def _set_autocommit(self, autocommit):
        with self.wrap_database_errors:
            self.connection.autocommit = autocommit

    def check_constraints(self, table_names=None):
        """
        Check constraints by setting them to immediate. Return them to deferred
        afterward.
        """
        with self.cursor() as cursor:
            cursor.execute("SET CONSTRAINTS ALL IMMEDIATE")
            cursor.execute("SET CONSTRAINTS ALL DEFERRED")

    def is_usable(self):
        try:
            # Use a psycopg cursor directly, bypassing Plain's utilities.
            with self.connection.cursor() as cursor:
                cursor.execute("SELECT 1")
        except Database.Error:
            return False
        else:
            return True

    @contextmanager
    def _nodb_cursor(self):
        cursor = None
        try:
            with super()._nodb_cursor() as cursor:
                yield cursor
        except (Database.DatabaseError, WrappedDatabaseError):
            if cursor is not None:
                raise
            warnings.warn(
                "Normally Plain will use a connection to the 'postgres' database "
                "to avoid running initialization queries against the production "
                "database when it's not needed (for example, when running tests). "
                "Plain was unable to create a connection to the 'postgres' database "
                "and will use the first PostgreSQL database instead.",
                RuntimeWarning,
            )
            for connection in connections.all():
                if (
                    connection.vendor == "postgresql"
                    and connection.settings_dict["NAME"] != "postgres"
                ):
                    conn = self.__class__(
                        {
                            **self.settings_dict,
                            "NAME": connection.settings_dict["NAME"],
                        },
                        alias=self.alias,
                    )
                    try:
                        with conn.cursor() as cursor:
                            yield cursor
                    finally:
                        conn.close()
                    break
            else:
                raise

    @cached_property
    def pg_version(self):
        with self.temporary_connection():
            return self.connection.info.server_version

    def make_debug_cursor(self, cursor):
        return CursorDebugWrapper(cursor, self)


if is_psycopg3:

    class CursorMixin:
        """
        A mixin adding callproc() support to psycopg cursors.
        """

        def callproc(self, name, args=None):
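            """
            Implement the DB-API callproc() method, which psycopg 3 doesn't
            provide. For example, cursor.callproc("version") executes
            SELECT * FROM "version"() and returns args unchanged.
            """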
            if not isinstance(name, sql.Identifier):
                name = sql.Identifier(name)

            qparts = [sql.SQL("SELECT * FROM "), name, sql.SQL("(")]
            if args:
                for item in args:
                    qparts.append(sql.Literal(item))
                    qparts.append(sql.SQL(","))
                del qparts[-1]

            qparts.append(sql.SQL(")"))
            stmt = sql.Composed(qparts)
            self.execute(stmt)
            return args

    class ServerBindingCursor(CursorMixin, Database.Cursor):
        pass

    class Cursor(CursorMixin, Database.ClientCursor):
        pass

    class CursorDebugWrapper(BaseCursorDebugWrapper):
        def copy(self, statement):
            with self.debug_sql(statement):
                return self.cursor.copy(statement)

else:
    Cursor = psycopg2.extensions.cursor

    class CursorDebugWrapper(BaseCursorDebugWrapper):
        def copy_expert(self, sql, file, *args):
            with self.debug_sql(sql):
                return self.cursor.copy_expert(sql, file, *args)

        def copy_to(self, file, table, *args, **kwargs):
            with self.debug_sql(sql="COPY %s TO STDOUT" % table):
                return self.cursor.copy_to(file, table, *args, **kwargs)