Plain is headed towards 1.0! Subscribe for development updates →

  1import json
  2from functools import lru_cache, partial
  3
  4from plain.models.backends.base.operations import BaseDatabaseOperations
  5from plain.models.backends.postgresql.psycopg_any import (
  6    Inet,
  7    Jsonb,
  8    errors,
  9    is_psycopg3,
 10    mogrify,
 11)
 12from plain.models.backends.utils import split_tzname_delta
 13from plain.models.constants import OnConflict
 14from plain.runtime import settings
 15from plain.utils.regex_helper import _lazy_re_compile
 16
 17
 18@lru_cache
 19def get_json_dumps(encoder):
 20    if encoder is None:
 21        return json.dumps
 22    return partial(json.dumps, cls=encoder)
 23
 24
 25class DatabaseOperations(BaseDatabaseOperations):
 26    cast_char_field_without_max_length = "varchar"
 27    explain_prefix = "EXPLAIN"
 28    explain_options = frozenset(
 29        [
 30            "ANALYZE",
 31            "BUFFERS",
 32            "COSTS",
 33            "SETTINGS",
 34            "SUMMARY",
 35            "TIMING",
 36            "VERBOSE",
 37            "WAL",
 38        ]
 39    )
 40    cast_data_types = {
 41        "AutoField": "integer",
 42        "BigAutoField": "bigint",
 43        "SmallAutoField": "smallint",
 44    }
 45
 46    if is_psycopg3:
 47        from psycopg.types import numeric
 48
 49        integerfield_type_map = {
 50            "SmallIntegerField": numeric.Int2,
 51            "IntegerField": numeric.Int4,
 52            "BigIntegerField": numeric.Int8,
 53            "PositiveSmallIntegerField": numeric.Int2,
 54            "PositiveIntegerField": numeric.Int4,
 55            "PositiveBigIntegerField": numeric.Int8,
 56        }
 57
 58    def unification_cast_sql(self, output_field):
 59        internal_type = output_field.get_internal_type()
 60        if internal_type in (
 61            "GenericIPAddressField",
 62            "IPAddressField",
 63            "TimeField",
 64            "UUIDField",
 65        ):
 66            # PostgreSQL will resolve a union as type 'text' if input types are
 67            # 'unknown'.
 68            # https://www.postgresql.org/docs/current/typeconv-union-case.html
 69            # These fields cannot be implicitly cast back in the default
 70            # PostgreSQL configuration so we need to explicitly cast them.
 71            # We must also remove components of the type within brackets:
 72            # varchar(255) -> varchar.
 73            return (
 74                "CAST(%%s AS %s)" % output_field.db_type(self.connection).split("(")[0]
 75            )
 76        return "%s"
 77
 78    # EXTRACT format cannot be passed in parameters.
 79    _extract_format_re = _lazy_re_compile(r"[A-Z_]+")
 80
 81    def date_extract_sql(self, lookup_type, sql, params):
 82        # https://www.postgresql.org/docs/current/functions-datetime.html#FUNCTIONS-DATETIME-EXTRACT
 83        if lookup_type == "week_day":
 84            # For consistency across backends, we return Sunday=1, Saturday=7.
 85            return f"EXTRACT(DOW FROM {sql}) + 1", params
 86        elif lookup_type == "iso_week_day":
 87            return f"EXTRACT(ISODOW FROM {sql})", params
 88        elif lookup_type == "iso_year":
 89            return f"EXTRACT(ISOYEAR FROM {sql})", params
 90
 91        lookup_type = lookup_type.upper()
 92        if not self._extract_format_re.fullmatch(lookup_type):
 93            raise ValueError(f"Invalid lookup type: {lookup_type!r}")
 94        return f"EXTRACT({lookup_type} FROM {sql})", params
 95
 96    def date_trunc_sql(self, lookup_type, sql, params, tzname=None):
 97        sql, params = self._convert_sql_to_tz(sql, params, tzname)
 98        # https://www.postgresql.org/docs/current/functions-datetime.html#FUNCTIONS-DATETIME-TRUNC
 99        return f"DATE_TRUNC(%s, {sql})", (lookup_type, *params)
100
101    def _prepare_tzname_delta(self, tzname):
102        tzname, sign, offset = split_tzname_delta(tzname)
103        if offset:
104            sign = "-" if sign == "+" else "+"
105            return f"{tzname}{sign}{offset}"
106        return tzname
107
108    def _convert_sql_to_tz(self, sql, params, tzname):
109        if tzname and settings.USE_TZ:
110            tzname_param = self._prepare_tzname_delta(tzname)
111            return f"{sql} AT TIME ZONE %s", (*params, tzname_param)
112        return sql, params
113
114    def datetime_cast_date_sql(self, sql, params, tzname):
115        sql, params = self._convert_sql_to_tz(sql, params, tzname)
116        return f"({sql})::date", params
117
118    def datetime_cast_time_sql(self, sql, params, tzname):
119        sql, params = self._convert_sql_to_tz(sql, params, tzname)
120        return f"({sql})::time", params
121
122    def datetime_extract_sql(self, lookup_type, sql, params, tzname):
123        sql, params = self._convert_sql_to_tz(sql, params, tzname)
124        if lookup_type == "second":
125            # Truncate fractional seconds.
126            return f"EXTRACT(SECOND FROM DATE_TRUNC(%s, {sql}))", ("second", *params)
127        return self.date_extract_sql(lookup_type, sql, params)
128
129    def datetime_trunc_sql(self, lookup_type, sql, params, tzname):
130        sql, params = self._convert_sql_to_tz(sql, params, tzname)
131        # https://www.postgresql.org/docs/current/functions-datetime.html#FUNCTIONS-DATETIME-TRUNC
132        return f"DATE_TRUNC(%s, {sql})", (lookup_type, *params)
133
134    def time_extract_sql(self, lookup_type, sql, params):
135        if lookup_type == "second":
136            # Truncate fractional seconds.
137            return f"EXTRACT(SECOND FROM DATE_TRUNC(%s, {sql}))", ("second", *params)
138        return self.date_extract_sql(lookup_type, sql, params)
139
140    def time_trunc_sql(self, lookup_type, sql, params, tzname=None):
141        sql, params = self._convert_sql_to_tz(sql, params, tzname)
142        return f"DATE_TRUNC(%s, {sql})::time", (lookup_type, *params)
143
144    def deferrable_sql(self):
145        return " DEFERRABLE INITIALLY DEFERRED"
146
147    def fetch_returned_insert_rows(self, cursor):
148        """
149        Given a cursor object that has just performed an INSERT...RETURNING
150        statement into a table, return the tuple of returned data.
151        """
152        return cursor.fetchall()
153
154    def lookup_cast(self, lookup_type, internal_type=None):
155        lookup = "%s"
156
157        if lookup_type == "isnull" and internal_type in (
158            "CharField",
159            "EmailField",
160            "TextField",
161            "CICharField",
162            "CIEmailField",
163            "CITextField",
164        ):
165            return "%s::text"
166
167        # Cast text lookups to text to allow things like filter(x__contains=4)
168        if lookup_type in (
169            "iexact",
170            "contains",
171            "icontains",
172            "startswith",
173            "istartswith",
174            "endswith",
175            "iendswith",
176            "regex",
177            "iregex",
178        ):
179            if internal_type in ("IPAddressField", "GenericIPAddressField"):
180                lookup = "HOST(%s)"
181            # RemovedInDjango51Warning.
182            elif internal_type in ("CICharField", "CIEmailField", "CITextField"):
183                lookup = "%s::citext"
184            else:
185                lookup = "%s::text"
186
187        # Use UPPER(x) for case-insensitive lookups; it's faster.
188        if lookup_type in ("iexact", "icontains", "istartswith", "iendswith"):
189            lookup = "UPPER(%s)" % lookup
190
191        return lookup
192
193    def no_limit_value(self):
194        return None
195
196    def prepare_sql_script(self, sql):
197        return [sql]
198
199    def quote_name(self, name):
200        if name.startswith('"') and name.endswith('"'):
201            return name  # Quoting once is enough.
202        return '"%s"' % name
203
204    def compose_sql(self, sql, params):
205        return mogrify(sql, params, self.connection)
206
207    def set_time_zone_sql(self):
208        return "SELECT set_config('TimeZone', %s, false)"
209
210    def sequence_reset_by_name_sql(self, style, sequences):
211        # 'ALTER SEQUENCE sequence_name RESTART WITH 1;'... style SQL statements
212        # to reset sequence indices
213        sql = []
214        for sequence_info in sequences:
215            table_name = sequence_info["table"]
216            # 'id' will be the case if it's an m2m using an autogenerated
217            # intermediate table (see BaseDatabaseIntrospection.sequence_list).
218            column_name = sequence_info["column"] or "id"
219            sql.append(
220                "{} setval(pg_get_serial_sequence('{}','{}'), 1, false);".format(
221                    style.SQL_KEYWORD("SELECT"),
222                    style.SQL_TABLE(self.quote_name(table_name)),
223                    style.SQL_FIELD(column_name),
224                )
225            )
226        return sql
227
228    def tablespace_sql(self, tablespace, inline=False):
229        if inline:
230            return "USING INDEX TABLESPACE %s" % self.quote_name(tablespace)
231        else:
232            return "TABLESPACE %s" % self.quote_name(tablespace)
233
234    def sequence_reset_sql(self, style, model_list):
235        from plain import models
236
237        output = []
238        qn = self.quote_name
239        for model in model_list:
240            # Use `coalesce` to set the sequence for each model to the max pk
241            # value if there are records, or 1 if there are none. Set the
242            # `is_called` property (the third argument to `setval`) to true if
243            # there are records (as the max pk value is already in use),
244            # otherwise set it to false. Use pg_get_serial_sequence to get the
245            # underlying sequence name from the table name and column name.
246
247            for f in model._meta.local_fields:
248                if isinstance(f, models.AutoField):
249                    output.append(
250                        "{} setval(pg_get_serial_sequence('{}','{}'), "
251                        "coalesce(max({}), 1), max({}) {} null) {} {};".format(
252                            style.SQL_KEYWORD("SELECT"),
253                            style.SQL_TABLE(qn(model._meta.db_table)),
254                            style.SQL_FIELD(f.column),
255                            style.SQL_FIELD(qn(f.column)),
256                            style.SQL_FIELD(qn(f.column)),
257                            style.SQL_KEYWORD("IS NOT"),
258                            style.SQL_KEYWORD("FROM"),
259                            style.SQL_TABLE(qn(model._meta.db_table)),
260                        )
261                    )
262                    # Only one AutoField is allowed per model, so don't bother
263                    # continuing.
264                    break
265        return output
266
267    def prep_for_iexact_query(self, x):
268        return x
269
270    def max_name_length(self):
271        """
272        Return the maximum length of an identifier.
273
274        The maximum length of an identifier is 63 by default, but can be
275        changed by recompiling PostgreSQL after editing the NAMEDATALEN
276        macro in src/include/pg_config_manual.h.
277
278        This implementation returns 63, but can be overridden by a custom
279        database backend that inherits most of its behavior from this one.
280        """
281        return 63
282
283    def distinct_sql(self, fields, params):
284        if fields:
285            params = [param for param_list in params for param in param_list]
286            return (["DISTINCT ON (%s)" % ", ".join(fields)], params)
287        else:
288            return ["DISTINCT"], []
289
290    if is_psycopg3:
291
292        def last_executed_query(self, cursor, sql, params):
293            try:
294                return self.compose_sql(sql, params)
295            except errors.DataError:
296                return None
297
298    else:
299
300        def last_executed_query(self, cursor, sql, params):
301            # https://www.psycopg.org/docs/cursor.html#cursor.query
302            # The query attribute is a Psycopg extension to the DB API 2.0.
303            if cursor.query is not None:
304                return cursor.query.decode()
305            return None
306
307    def return_insert_columns(self, fields):
308        if not fields:
309            return "", ()
310        columns = [
311            "{}.{}".format(
312                self.quote_name(field.model._meta.db_table),
313                self.quote_name(field.column),
314            )
315            for field in fields
316        ]
317        return "RETURNING %s" % ", ".join(columns), ()
318
319    def bulk_insert_sql(self, fields, placeholder_rows):
320        placeholder_rows_sql = (", ".join(row) for row in placeholder_rows)
321        values_sql = ", ".join("(%s)" % sql for sql in placeholder_rows_sql)
322        return "VALUES " + values_sql
323
324    if is_psycopg3:
325
326        def adapt_integerfield_value(self, value, internal_type):
327            if value is None or hasattr(value, "resolve_expression"):
328                return value
329            return self.integerfield_type_map[internal_type](value)
330
331    def adapt_datefield_value(self, value):
332        return value
333
334    def adapt_datetimefield_value(self, value):
335        return value
336
337    def adapt_timefield_value(self, value):
338        return value
339
340    def adapt_decimalfield_value(self, value, max_digits=None, decimal_places=None):
341        return value
342
343    def adapt_ipaddressfield_value(self, value):
344        if value:
345            return Inet(value)
346        return None
347
348    def adapt_json_value(self, value, encoder):
349        return Jsonb(value, dumps=get_json_dumps(encoder))
350
351    def subtract_temporals(self, internal_type, lhs, rhs):
352        if internal_type == "DateField":
353            lhs_sql, lhs_params = lhs
354            rhs_sql, rhs_params = rhs
355            params = (*lhs_params, *rhs_params)
356            return f"(interval '1 day' * ({lhs_sql} - {rhs_sql}))", params
357        return super().subtract_temporals(internal_type, lhs, rhs)
358
359    def explain_query_prefix(self, format=None, **options):
360        extra = {}
361        # Normalize options.
362        if options:
363            options = {
364                name.upper(): "true" if value else "false"
365                for name, value in options.items()
366            }
367            for valid_option in self.explain_options:
368                value = options.pop(valid_option, None)
369                if value is not None:
370                    extra[valid_option] = value
371        prefix = super().explain_query_prefix(format, **options)
372        if format:
373            extra["FORMAT"] = format
374        if extra:
375            prefix += " (%s)" % ", ".join("{} {}".format(*i) for i in extra.items())
376        return prefix
377
378    def on_conflict_suffix_sql(self, fields, on_conflict, update_fields, unique_fields):
379        if on_conflict == OnConflict.IGNORE:
380            return "ON CONFLICT DO NOTHING"
381        if on_conflict == OnConflict.UPDATE:
382            return "ON CONFLICT({}) DO UPDATE SET {}".format(
383                ", ".join(map(self.quote_name, unique_fields)),
384                ", ".join(
385                    [
386                        f"{field} = EXCLUDED.{field}"
387                        for field in map(self.quote_name, update_fields)
388                    ]
389                ),
390            )
391        return super().on_conflict_suffix_sql(
392            fields,
393            on_conflict,
394            update_fields,
395            unique_fields,
396        )