Plain is headed towards 1.0! Subscribe for development updates →

  1from __future__ import annotations
  2
  3import datetime
  4import uuid
  5from collections.abc import Iterable
  6from typing import TYPE_CHECKING, Any
  7
  8from plain.models.backends.base.operations import BaseDatabaseOperations
  9from plain.models.backends.utils import CursorWrapper, split_tzname_delta
 10from plain.models.constants import OnConflict
 11from plain.models.expressions import Exists, ExpressionWrapper, ResolvableExpression
 12from plain.models.lookups import Lookup
 13from plain.utils import timezone
 14from plain.utils.encoding import force_str
 15from plain.utils.regex_helper import _lazy_re_compile
 16
 17if TYPE_CHECKING:
 18    from plain.models.backends.base.base import BaseDatabaseWrapper
 19    from plain.models.backends.mysql.base import MySQLDatabaseWrapper
 20    from plain.models.fields import Field
 21
 22
 23class DatabaseOperations(BaseDatabaseOperations):
 24    # Type checker hint: connection is always MySQLDatabaseWrapper in this class
 25    connection: MySQLDatabaseWrapper
 26
 27    compiler_module = "plain.models.backends.mysql.compiler"
 28
 29    # MySQL stores positive fields as UNSIGNED ints.
 30    integer_field_ranges = {
 31        **BaseDatabaseOperations.integer_field_ranges,
 32        "PositiveSmallIntegerField": (0, 65535),
 33        "PositiveIntegerField": (0, 4294967295),
 34        "PositiveBigIntegerField": (0, 18446744073709551615),
 35    }
 36    cast_data_types = {
 37        "PrimaryKeyField": "signed integer",
 38        "CharField": "char(%(max_length)s)",
 39        "DecimalField": "decimal(%(max_digits)s, %(decimal_places)s)",
 40        "TextField": "char",
 41        "IntegerField": "signed integer",
 42        "BigIntegerField": "signed integer",
 43        "SmallIntegerField": "signed integer",
 44        "PositiveBigIntegerField": "unsigned integer",
 45        "PositiveIntegerField": "unsigned integer",
 46        "PositiveSmallIntegerField": "unsigned integer",
 47        "DurationField": "signed integer",
 48    }
 49    cast_char_field_without_max_length = "char"
 50    explain_prefix = "EXPLAIN"
 51
 52    # EXTRACT format cannot be passed in parameters.
 53    _extract_format_re = _lazy_re_compile(r"[A-Z_]+")
 54
 55    def date_extract_sql(
 56        self, lookup_type: str, sql: str, params: list[Any] | tuple[Any, ...]
 57    ) -> tuple[str, list[Any] | tuple[Any, ...]]:
 58        # https://dev.mysql.com/doc/mysql/en/date-and-time-functions.html
 59        if lookup_type == "week_day":
 60            # DAYOFWEEK() returns an integer, 1-7, Sunday=1.
 61            return f"DAYOFWEEK({sql})", params
 62        elif lookup_type == "iso_week_day":
 63            # WEEKDAY() returns an integer, 0-6, Monday=0.
 64            return f"WEEKDAY({sql}) + 1", params
 65        elif lookup_type == "week":
 66            # Override the value of default_week_format for consistency with
 67            # other database backends.
 68            # Mode 3: Monday, 1-53, with 4 or more days this year.
 69            return f"WEEK({sql}, 3)", params
 70        elif lookup_type == "iso_year":
 71            # Get the year part from the YEARWEEK function, which returns a
 72            # number as year * 100 + week.
 73            return f"TRUNCATE(YEARWEEK({sql}, 3), -2) / 100", params
 74        else:
 75            # EXTRACT returns 1-53 based on ISO-8601 for the week number.
 76            lookup_type = lookup_type.upper()
 77            if not self._extract_format_re.fullmatch(lookup_type):
 78                raise ValueError(f"Invalid loookup type: {lookup_type!r}")
 79            return f"EXTRACT({lookup_type} FROM {sql})", params
 80
 81    def date_trunc_sql(
 82        self,
 83        lookup_type: str,
 84        sql: str,
 85        params: list[Any] | tuple[Any, ...],
 86        tzname: str | None = None,
 87    ) -> tuple[str, list[Any] | tuple[Any, ...]]:
 88        sql, params = self._convert_sql_to_tz(sql, params, tzname)
 89        fields = {
 90            "year": "%Y-01-01",
 91            "month": "%Y-%m-01",
 92        }
 93        if lookup_type in fields:
 94            format_str = fields[lookup_type]
 95            return f"CAST(DATE_FORMAT({sql}, %s) AS DATE)", (*params, format_str)
 96        elif lookup_type == "quarter":
 97            return (
 98                f"MAKEDATE(YEAR({sql}), 1) + "
 99                f"INTERVAL QUARTER({sql}) QUARTER - INTERVAL 1 QUARTER",
100                (*params, *params),
101            )
102        elif lookup_type == "week":
103            return f"DATE_SUB({sql}, INTERVAL WEEKDAY({sql}) DAY)", (*params, *params)
104        else:
105            return f"DATE({sql})", params
106
107    def _prepare_tzname_delta(self, tzname: str) -> str:
108        tzname, sign, offset = split_tzname_delta(tzname)
109        return f"{sign}{offset}" if offset else tzname
110
111    def _convert_sql_to_tz(
112        self, sql: str, params: list[Any] | tuple[Any, ...], tzname: str | None
113    ) -> tuple[str, list[Any] | tuple[Any, ...]]:
114        if tzname and self.connection.timezone_name != tzname:
115            return f"CONVERT_TZ({sql}, %s, %s)", (
116                *params,
117                self.connection.timezone_name,
118                self._prepare_tzname_delta(tzname),
119            )
120        return sql, params
121
122    def datetime_cast_date_sql(
123        self, sql: str, params: list[Any] | tuple[Any, ...], tzname: str | None
124    ) -> tuple[str, list[Any] | tuple[Any, ...]]:
125        sql, params = self._convert_sql_to_tz(sql, params, tzname)
126        return f"DATE({sql})", params
127
128    def datetime_cast_time_sql(
129        self, sql: str, params: list[Any] | tuple[Any, ...], tzname: str | None
130    ) -> tuple[str, list[Any] | tuple[Any, ...]]:
131        sql, params = self._convert_sql_to_tz(sql, params, tzname)
132        return f"TIME({sql})", params
133
134    def datetime_extract_sql(
135        self,
136        lookup_type: str,
137        sql: str,
138        params: list[Any] | tuple[Any, ...],
139        tzname: str | None,
140    ) -> tuple[str, list[Any] | tuple[Any, ...]]:
141        sql, params = self._convert_sql_to_tz(sql, params, tzname)
142        return self.date_extract_sql(lookup_type, sql, params)
143
144    def datetime_trunc_sql(
145        self,
146        lookup_type: str,
147        sql: str,
148        params: list[Any] | tuple[Any, ...],
149        tzname: str | None,
150    ) -> tuple[str, list[Any] | tuple[Any, ...]]:
151        sql, params = self._convert_sql_to_tz(sql, params, tzname)
152        fields = ["year", "month", "day", "hour", "minute", "second"]
153        format = ("%Y-", "%m", "-%d", " %H:", "%i", ":%s")
154        format_def = ("0000-", "01", "-01", " 00:", "00", ":00")
155        if lookup_type == "quarter":
156            return (
157                f"CAST(DATE_FORMAT(MAKEDATE(YEAR({sql}), 1) + "
158                f"INTERVAL QUARTER({sql}) QUARTER - "
159                f"INTERVAL 1 QUARTER, %s) AS DATETIME)"
160            ), (*params, *params, "%Y-%m-01 00:00:00")
161        if lookup_type == "week":
162            return (
163                f"CAST(DATE_FORMAT("
164                f"DATE_SUB({sql}, INTERVAL WEEKDAY({sql}) DAY), %s) AS DATETIME)"
165            ), (*params, *params, "%Y-%m-%d 00:00:00")
166        try:
167            i = fields.index(lookup_type) + 1
168        except ValueError:
169            pass
170        else:
171            format_str = "".join(format[:i] + format_def[i:])
172            return f"CAST(DATE_FORMAT({sql}, %s) AS DATETIME)", (*params, format_str)
173        return sql, params
174
175    def time_trunc_sql(
176        self,
177        lookup_type: str,
178        sql: str,
179        params: list[Any] | tuple[Any, ...],
180        tzname: str | None = None,
181    ) -> tuple[str, list[Any] | tuple[Any, ...]]:
182        sql, params = self._convert_sql_to_tz(sql, params, tzname)
183        fields = {
184            "hour": "%H:00:00",
185            "minute": "%H:%i:00",
186            "second": "%H:%i:%s",
187        }
188        if lookup_type in fields:
189            format_str = fields[lookup_type]
190            return f"CAST(DATE_FORMAT({sql}, %s) AS TIME)", (*params, format_str)
191        else:
192            return f"TIME({sql})", params
193
194    def fetch_returned_insert_rows(self, cursor: CursorWrapper) -> list[Any]:
195        """
196        Given a cursor object that has just performed an INSERT...RETURNING
197        statement into a table, return the tuple of returned data.
198        """
199        return cursor.fetchall()
200
201    def format_for_duration_arithmetic(self, sql: str) -> str:
202        return f"INTERVAL {sql} MICROSECOND"
203
204    def force_no_ordering(self) -> list[tuple[None, tuple[str, list[Any], bool]]]:
205        """
206        "ORDER BY NULL" prevents MySQL from implicitly ordering by grouped
207        columns. If no ordering would otherwise be applied, we don't want any
208        implicit sorting going on.
209        """
210        return [(None, ("NULL", [], False))]
211
212    def adapt_decimalfield_value(
213        self,
214        value: Any,
215        max_digits: int | None = None,
216        decimal_places: int | None = None,
217    ) -> Any:
218        return value
219
220    def last_executed_query(
221        self, cursor: CursorWrapper, sql: str, params: Any
222    ) -> str | None:
223        # With MySQLdb, cursor objects have an (undocumented) "_executed"
224        # attribute where the exact query sent to the database is saved.
225        # See MySQLdb/cursors.py in the source distribution.
226        # MySQLdb returns string, PyMySQL bytes.
227        return force_str(getattr(cursor, "_executed", None), errors="replace")
228
229    def no_limit_value(self) -> int:
230        # 2**64 - 1, as recommended by the MySQL documentation
231        return 18446744073709551615
232
233    def quote_name(self, name: str) -> str:
234        if name.startswith("`") and name.endswith("`"):
235            return name  # Quoting once is enough.
236        return f"`{name}`"
237
238    def return_insert_columns(self, fields: list[Any]) -> tuple[str, tuple[Any, ...]]:
239        # MySQL and MariaDB < 10.5.0 don't support an INSERT...RETURNING
240        # statement.
241        if not fields:
242            return "", ()
243        columns = [
244            f"{self.quote_name(field.model.model_options.db_table)}.{self.quote_name(field.column)}"
245            for field in fields
246        ]
247        return "RETURNING {}".format(", ".join(columns)), ()
248
249    def validate_autopk_value(self, value: int) -> int:
250        # Zero in AUTO_INCREMENT field does not work without the
251        # NO_AUTO_VALUE_ON_ZERO SQL mode.
252        if value == 0 and not self.connection.features.allows_auto_pk_0:
253            raise ValueError(
254                "The database backend does not accept 0 as a value for PrimaryKeyField."
255            )
256        return value
257
258    def adapt_datetimefield_value(
259        self, value: datetime.datetime | Any | None
260    ) -> str | Any | None:
261        if value is None:
262            return None
263
264        # Expression values are adapted by the database.
265        if isinstance(value, ResolvableExpression):
266            return value
267
268        # MySQL doesn't support tz-aware datetimes
269        if timezone.is_aware(value):
270            value = timezone.make_naive(value, self.connection.timezone)
271        return str(value)
272
273    def adapt_timefield_value(
274        self, value: datetime.time | Any | None
275    ) -> str | Any | None:
276        if value is None:
277            return None
278
279        # Expression values are adapted by the database.
280        if isinstance(value, ResolvableExpression):
281            return value
282
283        # MySQL doesn't support tz-aware times
284        if timezone.is_aware(value):
285            raise ValueError("MySQL backend does not support timezone-aware times.")
286
287        return value.isoformat(timespec="microseconds")
288
289    def max_name_length(self) -> int:
290        return 64
291
292    def pk_default_value(self) -> str:
293        return "NULL"
294
295    def bulk_insert_sql(
296        self, fields: list[Any], placeholder_rows: list[list[str]]
297    ) -> str:
298        placeholder_rows_sql = (", ".join(row) for row in placeholder_rows)
299        values_sql = ", ".join(f"({sql})" for sql in placeholder_rows_sql)
300        return "VALUES " + values_sql
301
302    def combine_expression(self, connector: str, sub_expressions: list[str]) -> str:
303        if connector == "^":
304            return "POW({})".format(",".join(sub_expressions))
305        # Convert the result to a signed integer since MySQL's binary operators
306        # return an unsigned integer.
307        elif connector in ("&", "|", "<<", "#"):
308            connector = "^" if connector == "#" else connector
309            return f"CONVERT({connector.join(sub_expressions)}, SIGNED)"
310        elif connector == ">>":
311            lhs, rhs = sub_expressions
312            return f"FLOOR({lhs} / POW(2, {rhs}))"
313        return super().combine_expression(connector, sub_expressions)
314
315    def get_db_converters(self, expression: Any) -> list[Any]:
316        converters = super().get_db_converters(expression)
317        internal_type = expression.output_field.get_internal_type()
318        if internal_type == "BooleanField":
319            converters.append(self.convert_booleanfield_value)
320        elif internal_type == "DateTimeField":
321            converters.append(self.convert_datetimefield_value)
322        elif internal_type == "UUIDField":
323            converters.append(self.convert_uuidfield_value)
324        return converters
325
326    def convert_booleanfield_value(
327        self, value: Any, expression: Any, connection: BaseDatabaseWrapper
328    ) -> Any:
329        if value in (0, 1):
330            value = bool(value)
331        return value
332
333    def convert_datetimefield_value(
334        self, value: Any, expression: Any, connection: BaseDatabaseWrapper
335    ) -> datetime.datetime | None:
336        if value is not None:
337            value = timezone.make_aware(value, self.connection.timezone)
338        return value
339
340    def convert_uuidfield_value(
341        self, value: Any, expression: Any, connection: BaseDatabaseWrapper
342    ) -> uuid.UUID | None:
343        if value is not None:
344            value = uuid.UUID(value)
345        return value
346
347    def binary_placeholder_sql(self, value: Any) -> str:
348        return (
349            "_binary %s" if value is not None and not hasattr(value, "as_sql") else "%s"
350        )
351
352    def subtract_temporals(
353        self,
354        internal_type: str,
355        lhs: tuple[str, list[Any] | tuple[Any, ...]],
356        rhs: tuple[str, list[Any] | tuple[Any, ...]],
357    ) -> tuple[str, tuple[Any, ...]]:
358        lhs_sql, lhs_params = lhs
359        rhs_sql, rhs_params = rhs
360        if internal_type == "TimeField":
361            if self.connection.mysql_is_mariadb:
362                # MariaDB includes the microsecond component in TIME_TO_SEC as
363                # a decimal. MySQL returns an integer without microseconds.
364                return (
365                    f"CAST((TIME_TO_SEC({lhs_sql}) - TIME_TO_SEC({rhs_sql})) "
366                    "* 1000000 AS SIGNED)"
367                ), (
368                    *lhs_params,
369                    *rhs_params,
370                )
371            return (
372                f"((TIME_TO_SEC({lhs_sql}) * 1000000 + MICROSECOND({lhs_sql})) -"
373                f" (TIME_TO_SEC({rhs_sql}) * 1000000 + MICROSECOND({rhs_sql})))"
374            ), tuple(lhs_params) * 2 + tuple(rhs_params) * 2
375        params = (*rhs_params, *lhs_params)
376        return f"TIMESTAMPDIFF(MICROSECOND, {rhs_sql}, {lhs_sql})", params
377
378    def explain_query_prefix(self, format: str | None = None, **options: Any) -> str:
379        # Alias MySQL's TRADITIONAL to TEXT for consistency with other backends.
380        if format and format.upper() == "TEXT":
381            format = "TRADITIONAL"
382        elif (
383            not format and "TREE" in self.connection.features.supported_explain_formats
384        ):
385            # Use TREE by default (if supported) as it's more informative.
386            format = "TREE"
387        analyze = options.pop("analyze", False)
388        prefix = super().explain_query_prefix(format, **options)
389        if analyze and self.connection.features.supports_explain_analyze:
390            # MariaDB uses ANALYZE instead of EXPLAIN ANALYZE.
391            prefix = (
392                "ANALYZE" if self.connection.mysql_is_mariadb else prefix + " ANALYZE"
393            )
394        if format and not (analyze and not self.connection.mysql_is_mariadb):
395            # Only MariaDB supports the analyze option with formats.
396            prefix += f" FORMAT={format}"
397        return prefix
398
399    def regex_lookup(self, lookup_type: str) -> str:
400        # REGEXP_LIKE doesn't exist in MariaDB.
401        if self.connection.mysql_is_mariadb:
402            if lookup_type == "regex":
403                return "%s REGEXP BINARY %s"
404            return "%s REGEXP %s"
405
406        match_option = "c" if lookup_type == "regex" else "i"
407        return f"REGEXP_LIKE(%s, %s, '{match_option}')"
408
409    def insert_statement(self, on_conflict: Any = None) -> str:
410        if on_conflict == OnConflict.IGNORE:
411            return "INSERT IGNORE INTO"
412        return super().insert_statement(on_conflict=on_conflict)
413
414    def lookup_cast(self, lookup_type: str, internal_type: str | None = None) -> str:
415        lookup = "%s"
416        if internal_type == "JSONField":
417            if self.connection.mysql_is_mariadb or lookup_type in (
418                "iexact",
419                "contains",
420                "icontains",
421                "startswith",
422                "istartswith",
423                "endswith",
424                "iendswith",
425                "regex",
426                "iregex",
427            ):
428                lookup = "JSON_UNQUOTE(%s)"
429        return lookup
430
431    def conditional_expression_supported_in_where_clause(self, expression: Any) -> bool:
432        # MySQL ignores indexes with boolean fields unless they're compared
433        # directly to a boolean value.
434        if isinstance(expression, Exists | Lookup):
435            return True
436        if isinstance(expression, ExpressionWrapper) and expression.conditional:
437            return self.conditional_expression_supported_in_where_clause(
438                expression.expression
439            )
440        if getattr(expression, "conditional", False):
441            return False
442        return super().conditional_expression_supported_in_where_clause(expression)
443
444    def on_conflict_suffix_sql(
445        self,
446        fields: list[Field],
447        on_conflict: Any,
448        update_fields: Iterable[str],
449        unique_fields: Iterable[str],
450    ) -> str:
451        if on_conflict == OnConflict.UPDATE:
452            conflict_suffix_sql = "ON DUPLICATE KEY UPDATE %(fields)s"
453            # The use of VALUES() is deprecated in MySQL 8.0.20+. Instead, use
454            # aliases for the new row and its columns available in MySQL
455            # 8.0.19+.
456            if not self.connection.mysql_is_mariadb:
457                if self.connection.mysql_version >= (8, 0, 19):
458                    conflict_suffix_sql = f"AS new {conflict_suffix_sql}"
459                    field_sql = "%(field)s = new.%(field)s"
460                else:
461                    field_sql = "%(field)s = VALUES(%(field)s)"
462            # Use VALUE() on MariaDB.
463            else:
464                field_sql = "%(field)s = VALUE(%(field)s)"
465
466            fields_str = ", ".join(
467                [
468                    field_sql % {"field": field}
469                    for field in map(self.quote_name, update_fields)
470                ]
471            )
472            return conflict_suffix_sql % {"fields": fields_str}
473        return super().on_conflict_suffix_sql(
474            fields,
475            on_conflict,
476            update_fields,
477            unique_fields,
478        )