Plain is headed towards 1.0! Subscribe for development updates →

  1from __future__ import annotations
  2
  3import datetime
  4import uuid
  5from collections.abc import Iterable
  6from typing import TYPE_CHECKING, Any
  7
  8from plain.models.backends.base.operations import BaseDatabaseOperations
  9from plain.models.backends.utils import split_tzname_delta
 10from plain.models.constants import OnConflict
 11from plain.models.expressions import Exists, ExpressionWrapper
 12from plain.models.lookups import Lookup
 13from plain.utils import timezone
 14from plain.utils.encoding import force_str
 15from plain.utils.regex_helper import _lazy_re_compile
 16
 17if TYPE_CHECKING:
 18    from plain.models.backends.base.base import BaseDatabaseWrapper
 19    from plain.models.backends.mysql.base import MySQLDatabaseWrapper
 20    from plain.models.fields import Field
 21
 22
class DatabaseOperations(BaseDatabaseOperations):
    """
    MySQL/MariaDB-specific overrides of ``BaseDatabaseOperations``.
    """

    # Type checker hint: connection is always MySQLDatabaseWrapper in this class
    connection: MySQLDatabaseWrapper

    # Dotted path of the module holding this backend's SQL compiler classes.
    compiler_module = "plain.models.backends.mysql.compiler"

    # MySQL stores positive fields as UNSIGNED ints.
    # Each entry is a (minimum, maximum) pair; the base class ranges are kept
    # for every field type not listed here.
    integer_field_ranges = {
        **BaseDatabaseOperations.integer_field_ranges,
        "PositiveSmallIntegerField": (0, 65535),
        "PositiveIntegerField": (0, 4294967295),
        "PositiveBigIntegerField": (0, 18446744073709551615),
    }
    # Keyed by internal field type name. The values look like CAST() target
    # types, with %(...)s placeholders presumably filled from field attributes
    # -- NOTE(review): confirm against the base class consumer.
    cast_data_types = {
        "PrimaryKeyField": "signed integer",
        "CharField": "char(%(max_length)s)",
        "DecimalField": "decimal(%(max_digits)s, %(decimal_places)s)",
        "TextField": "char",
        "IntegerField": "signed integer",
        "BigIntegerField": "signed integer",
        "SmallIntegerField": "signed integer",
        "PositiveBigIntegerField": "unsigned integer",
        "PositiveIntegerField": "unsigned integer",
        "PositiveSmallIntegerField": "unsigned integer",
        "DurationField": "signed integer",
    }
    cast_char_field_without_max_length = "char"
    explain_prefix = "EXPLAIN"

    # EXTRACT format cannot be passed in parameters.
    # date_extract_sql() requires the upper-cased lookup type to fully match
    # this pattern before interpolating it into the SQL.
    _extract_format_re = _lazy_re_compile(r"[A-Z_]+")
 54
 55    def date_extract_sql(
 56        self, lookup_type: str, sql: str, params: list[Any] | tuple[Any, ...]
 57    ) -> tuple[str, list[Any] | tuple[Any, ...]]:
 58        # https://dev.mysql.com/doc/mysql/en/date-and-time-functions.html
 59        if lookup_type == "week_day":
 60            # DAYOFWEEK() returns an integer, 1-7, Sunday=1.
 61            return f"DAYOFWEEK({sql})", params
 62        elif lookup_type == "iso_week_day":
 63            # WEEKDAY() returns an integer, 0-6, Monday=0.
 64            return f"WEEKDAY({sql}) + 1", params
 65        elif lookup_type == "week":
 66            # Override the value of default_week_format for consistency with
 67            # other database backends.
 68            # Mode 3: Monday, 1-53, with 4 or more days this year.
 69            return f"WEEK({sql}, 3)", params
 70        elif lookup_type == "iso_year":
 71            # Get the year part from the YEARWEEK function, which returns a
 72            # number as year * 100 + week.
 73            return f"TRUNCATE(YEARWEEK({sql}, 3), -2) / 100", params
 74        else:
 75            # EXTRACT returns 1-53 based on ISO-8601 for the week number.
 76            lookup_type = lookup_type.upper()
 77            if not self._extract_format_re.fullmatch(lookup_type):
 78                raise ValueError(f"Invalid loookup type: {lookup_type!r}")
 79            return f"EXTRACT({lookup_type} FROM {sql})", params
 80
 81    def date_trunc_sql(
 82        self,
 83        lookup_type: str,
 84        sql: str,
 85        params: list[Any] | tuple[Any, ...],
 86        tzname: str | None = None,
 87    ) -> tuple[str, list[Any] | tuple[Any, ...]]:
 88        sql, params = self._convert_sql_to_tz(sql, params, tzname)
 89        fields = {
 90            "year": "%Y-01-01",
 91            "month": "%Y-%m-01",
 92        }
 93        if lookup_type in fields:
 94            format_str = fields[lookup_type]
 95            return f"CAST(DATE_FORMAT({sql}, %s) AS DATE)", (*params, format_str)
 96        elif lookup_type == "quarter":
 97            return (
 98                f"MAKEDATE(YEAR({sql}), 1) + "
 99                f"INTERVAL QUARTER({sql}) QUARTER - INTERVAL 1 QUARTER",
100                (*params, *params),
101            )
102        elif lookup_type == "week":
103            return f"DATE_SUB({sql}, INTERVAL WEEKDAY({sql}) DAY)", (*params, *params)
104        else:
105            return f"DATE({sql})", params
106
107    def _prepare_tzname_delta(self, tzname: str) -> str:
108        tzname, sign, offset = split_tzname_delta(tzname)
109        return f"{sign}{offset}" if offset else tzname
110
111    def _convert_sql_to_tz(
112        self, sql: str, params: list[Any] | tuple[Any, ...], tzname: str | None
113    ) -> tuple[str, list[Any] | tuple[Any, ...]]:
114        if tzname and self.connection.timezone_name != tzname:
115            return f"CONVERT_TZ({sql}, %s, %s)", (
116                *params,
117                self.connection.timezone_name,
118                self._prepare_tzname_delta(tzname),
119            )
120        return sql, params
121
122    def datetime_cast_date_sql(
123        self, sql: str, params: list[Any] | tuple[Any, ...], tzname: str | None
124    ) -> tuple[str, list[Any] | tuple[Any, ...]]:
125        sql, params = self._convert_sql_to_tz(sql, params, tzname)
126        return f"DATE({sql})", params
127
128    def datetime_cast_time_sql(
129        self, sql: str, params: list[Any] | tuple[Any, ...], tzname: str | None
130    ) -> tuple[str, list[Any] | tuple[Any, ...]]:
131        sql, params = self._convert_sql_to_tz(sql, params, tzname)
132        return f"TIME({sql})", params
133
134    def datetime_extract_sql(
135        self,
136        lookup_type: str,
137        sql: str,
138        params: list[Any] | tuple[Any, ...],
139        tzname: str | None,
140    ) -> tuple[str, list[Any] | tuple[Any, ...]]:
141        sql, params = self._convert_sql_to_tz(sql, params, tzname)
142        return self.date_extract_sql(lookup_type, sql, params)
143
144    def datetime_trunc_sql(
145        self,
146        lookup_type: str,
147        sql: str,
148        params: list[Any] | tuple[Any, ...],
149        tzname: str | None,
150    ) -> tuple[str, list[Any] | tuple[Any, ...]]:
151        sql, params = self._convert_sql_to_tz(sql, params, tzname)
152        fields = ["year", "month", "day", "hour", "minute", "second"]
153        format = ("%Y-", "%m", "-%d", " %H:", "%i", ":%s")
154        format_def = ("0000-", "01", "-01", " 00:", "00", ":00")
155        if lookup_type == "quarter":
156            return (
157                f"CAST(DATE_FORMAT(MAKEDATE(YEAR({sql}), 1) + "
158                f"INTERVAL QUARTER({sql}) QUARTER - "
159                f"INTERVAL 1 QUARTER, %s) AS DATETIME)"
160            ), (*params, *params, "%Y-%m-01 00:00:00")
161        if lookup_type == "week":
162            return (
163                f"CAST(DATE_FORMAT("
164                f"DATE_SUB({sql}, INTERVAL WEEKDAY({sql}) DAY), %s) AS DATETIME)"
165            ), (*params, *params, "%Y-%m-%d 00:00:00")
166        try:
167            i = fields.index(lookup_type) + 1
168        except ValueError:
169            pass
170        else:
171            format_str = "".join(format[:i] + format_def[i:])
172            return f"CAST(DATE_FORMAT({sql}, %s) AS DATETIME)", (*params, format_str)
173        return sql, params
174
175    def time_trunc_sql(
176        self,
177        lookup_type: str,
178        sql: str,
179        params: list[Any] | tuple[Any, ...],
180        tzname: str | None = None,
181    ) -> tuple[str, list[Any] | tuple[Any, ...]]:
182        sql, params = self._convert_sql_to_tz(sql, params, tzname)
183        fields = {
184            "hour": "%H:00:00",
185            "minute": "%H:%i:00",
186            "second": "%H:%i:%s",
187        }
188        if lookup_type in fields:
189            format_str = fields[lookup_type]
190            return f"CAST(DATE_FORMAT({sql}, %s) AS TIME)", (*params, format_str)
191        else:
192            return f"TIME({sql})", params
193
194    def fetch_returned_insert_rows(self, cursor: Any) -> list[Any]:
195        """
196        Given a cursor object that has just performed an INSERT...RETURNING
197        statement into a table, return the tuple of returned data.
198        """
199        return cursor.fetchall()
200
201    def format_for_duration_arithmetic(self, sql: str) -> str:
202        return f"INTERVAL {sql} MICROSECOND"
203
204    def force_no_ordering(self) -> list[tuple[None, tuple[str, list[Any], bool]]]:
205        """
206        "ORDER BY NULL" prevents MySQL from implicitly ordering by grouped
207        columns. If no ordering would otherwise be applied, we don't want any
208        implicit sorting going on.
209        """
210        return [(None, ("NULL", [], False))]
211
212    def adapt_decimalfield_value(
213        self,
214        value: Any,
215        max_digits: int | None = None,
216        decimal_places: int | None = None,
217    ) -> Any:
218        return value
219
220    def last_executed_query(self, cursor: Any, sql: str, params: Any) -> str | None:
221        # With MySQLdb, cursor objects have an (undocumented) "_executed"
222        # attribute where the exact query sent to the database is saved.
223        # See MySQLdb/cursors.py in the source distribution.
224        # MySQLdb returns string, PyMySQL bytes.
225        return force_str(getattr(cursor, "_executed", None), errors="replace")
226
227    def no_limit_value(self) -> int:
228        # 2**64 - 1, as recommended by the MySQL documentation
229        return 18446744073709551615
230
231    def quote_name(self, name: str) -> str:
232        if name.startswith("`") and name.endswith("`"):
233            return name  # Quoting once is enough.
234        return f"`{name}`"
235
236    def return_insert_columns(self, fields: list[Any]) -> tuple[str, tuple[Any, ...]]:
237        # MySQL and MariaDB < 10.5.0 don't support an INSERT...RETURNING
238        # statement.
239        if not fields:
240            return "", ()
241        columns = [
242            f"{self.quote_name(field.model.model_options.db_table)}.{self.quote_name(field.column)}"
243            for field in fields
244        ]
245        return "RETURNING {}".format(", ".join(columns)), ()
246
247    def validate_autopk_value(self, value: int) -> int:
248        # Zero in AUTO_INCREMENT field does not work without the
249        # NO_AUTO_VALUE_ON_ZERO SQL mode.
250        if value == 0 and not self.connection.features.allows_auto_pk_0:
251            raise ValueError(
252                "The database backend does not accept 0 as a value for PrimaryKeyField."
253            )
254        return value
255
256    def adapt_datetimefield_value(
257        self, value: datetime.datetime | Any | None
258    ) -> str | Any | None:
259        if value is None:
260            return None
261
262        # Expression values are adapted by the database.
263        if hasattr(value, "resolve_expression"):
264            return value
265
266        # MySQL doesn't support tz-aware datetimes
267        if timezone.is_aware(value):
268            value = timezone.make_naive(value, self.connection.timezone)
269        return str(value)
270
271    def adapt_timefield_value(
272        self, value: datetime.time | Any | None
273    ) -> str | Any | None:
274        if value is None:
275            return None
276
277        # Expression values are adapted by the database.
278        if hasattr(value, "resolve_expression"):
279            return value
280
281        # MySQL doesn't support tz-aware times
282        if timezone.is_aware(value):  # type: ignore[arg-type]
283            raise ValueError("MySQL backend does not support timezone-aware times.")
284
285        return value.isoformat(timespec="microseconds")
286
287    def max_name_length(self) -> int:
288        return 64
289
290    def pk_default_value(self) -> str:
291        return "NULL"
292
293    def bulk_insert_sql(
294        self, fields: list[Any], placeholder_rows: list[list[str]]
295    ) -> str:
296        placeholder_rows_sql = (", ".join(row) for row in placeholder_rows)
297        values_sql = ", ".join(f"({sql})" for sql in placeholder_rows_sql)
298        return "VALUES " + values_sql
299
300    def combine_expression(self, connector: str, sub_expressions: list[str]) -> str:
301        if connector == "^":
302            return "POW({})".format(",".join(sub_expressions))
303        # Convert the result to a signed integer since MySQL's binary operators
304        # return an unsigned integer.
305        elif connector in ("&", "|", "<<", "#"):
306            connector = "^" if connector == "#" else connector
307            return f"CONVERT({connector.join(sub_expressions)}, SIGNED)"
308        elif connector == ">>":
309            lhs, rhs = sub_expressions
310            return f"FLOOR({lhs} / POW(2, {rhs}))"
311        return super().combine_expression(connector, sub_expressions)
312
313    def get_db_converters(self, expression: Any) -> list[Any]:
314        converters = super().get_db_converters(expression)
315        internal_type = expression.output_field.get_internal_type()
316        if internal_type == "BooleanField":
317            converters.append(self.convert_booleanfield_value)
318        elif internal_type == "DateTimeField":
319            converters.append(self.convert_datetimefield_value)
320        elif internal_type == "UUIDField":
321            converters.append(self.convert_uuidfield_value)
322        return converters
323
324    def convert_booleanfield_value(
325        self, value: Any, expression: Any, connection: BaseDatabaseWrapper
326    ) -> Any:
327        if value in (0, 1):
328            value = bool(value)
329        return value
330
331    def convert_datetimefield_value(
332        self, value: Any, expression: Any, connection: BaseDatabaseWrapper
333    ) -> datetime.datetime | None:
334        if value is not None:
335            value = timezone.make_aware(value, self.connection.timezone)
336        return value
337
338    def convert_uuidfield_value(
339        self, value: Any, expression: Any, connection: BaseDatabaseWrapper
340    ) -> uuid.UUID | None:
341        if value is not None:
342            value = uuid.UUID(value)
343        return value
344
345    def binary_placeholder_sql(self, value: Any) -> str:
346        return (
347            "_binary %s" if value is not None and not hasattr(value, "as_sql") else "%s"
348        )
349
350    def subtract_temporals(
351        self,
352        internal_type: str,
353        lhs: tuple[str, list[Any] | tuple[Any, ...]],
354        rhs: tuple[str, list[Any] | tuple[Any, ...]],
355    ) -> tuple[str, tuple[Any, ...]]:
356        lhs_sql, lhs_params = lhs
357        rhs_sql, rhs_params = rhs
358        if internal_type == "TimeField":
359            if self.connection.mysql_is_mariadb:
360                # MariaDB includes the microsecond component in TIME_TO_SEC as
361                # a decimal. MySQL returns an integer without microseconds.
362                return (
363                    f"CAST((TIME_TO_SEC({lhs_sql}) - TIME_TO_SEC({rhs_sql})) "
364                    "* 1000000 AS SIGNED)"
365                ), (
366                    *lhs_params,
367                    *rhs_params,
368                )
369            return (
370                f"((TIME_TO_SEC({lhs_sql}) * 1000000 + MICROSECOND({lhs_sql})) -"
371                f" (TIME_TO_SEC({rhs_sql}) * 1000000 + MICROSECOND({rhs_sql})))"
372            ), tuple(lhs_params) * 2 + tuple(rhs_params) * 2
373        params = (*rhs_params, *lhs_params)
374        return f"TIMESTAMPDIFF(MICROSECOND, {rhs_sql}, {lhs_sql})", params
375
376    def explain_query_prefix(self, format: str | None = None, **options: Any) -> str:
377        # Alias MySQL's TRADITIONAL to TEXT for consistency with other backends.
378        if format and format.upper() == "TEXT":
379            format = "TRADITIONAL"
380        elif (
381            not format and "TREE" in self.connection.features.supported_explain_formats
382        ):
383            # Use TREE by default (if supported) as it's more informative.
384            format = "TREE"
385        analyze = options.pop("analyze", False)
386        prefix = super().explain_query_prefix(format, **options)
387        if analyze and self.connection.features.supports_explain_analyze:
388            # MariaDB uses ANALYZE instead of EXPLAIN ANALYZE.
389            prefix = (
390                "ANALYZE" if self.connection.mysql_is_mariadb else prefix + " ANALYZE"
391            )
392        if format and not (analyze and not self.connection.mysql_is_mariadb):
393            # Only MariaDB supports the analyze option with formats.
394            prefix += f" FORMAT={format}"
395        return prefix
396
397    def regex_lookup(self, lookup_type: str) -> str:
398        # REGEXP_LIKE doesn't exist in MariaDB.
399        if self.connection.mysql_is_mariadb:
400            if lookup_type == "regex":
401                return "%s REGEXP BINARY %s"
402            return "%s REGEXP %s"
403
404        match_option = "c" if lookup_type == "regex" else "i"
405        return f"REGEXP_LIKE(%s, %s, '{match_option}')"
406
407    def insert_statement(self, on_conflict: Any = None) -> str:
408        if on_conflict == OnConflict.IGNORE:
409            return "INSERT IGNORE INTO"
410        return super().insert_statement(on_conflict=on_conflict)
411
412    def lookup_cast(self, lookup_type: str, internal_type: str | None = None) -> str:
413        lookup = "%s"
414        if internal_type == "JSONField":
415            if self.connection.mysql_is_mariadb or lookup_type in (
416                "iexact",
417                "contains",
418                "icontains",
419                "startswith",
420                "istartswith",
421                "endswith",
422                "iendswith",
423                "regex",
424                "iregex",
425            ):
426                lookup = "JSON_UNQUOTE(%s)"
427        return lookup
428
429    def conditional_expression_supported_in_where_clause(self, expression: Any) -> bool:
430        # MySQL ignores indexes with boolean fields unless they're compared
431        # directly to a boolean value.
432        if isinstance(expression, Exists | Lookup):
433            return True
434        if isinstance(expression, ExpressionWrapper) and expression.conditional:
435            return self.conditional_expression_supported_in_where_clause(
436                expression.expression
437            )
438        if getattr(expression, "conditional", False):
439            return False
440        return super().conditional_expression_supported_in_where_clause(expression)
441
442    def on_conflict_suffix_sql(
443        self,
444        fields: list[Field],
445        on_conflict: Any,
446        update_fields: Iterable[str],
447        unique_fields: Iterable[str],
448    ) -> str:
449        if on_conflict == OnConflict.UPDATE:
450            conflict_suffix_sql = "ON DUPLICATE KEY UPDATE %(fields)s"
451            # The use of VALUES() is deprecated in MySQL 8.0.20+. Instead, use
452            # aliases for the new row and its columns available in MySQL
453            # 8.0.19+.
454            if not self.connection.mysql_is_mariadb:
455                if self.connection.mysql_version >= (8, 0, 19):
456                    conflict_suffix_sql = f"AS new {conflict_suffix_sql}"
457                    field_sql = "%(field)s = new.%(field)s"
458                else:
459                    field_sql = "%(field)s = VALUES(%(field)s)"
460            # Use VALUE() on MariaDB.
461            else:
462                field_sql = "%(field)s = VALUE(%(field)s)"
463
464            fields_str = ", ".join(
465                [
466                    field_sql % {"field": field}
467                    for field in map(self.quote_name, update_fields)
468                ]
469            )
470            return conflict_suffix_sql % {"fields": fields_str}
471        return super().on_conflict_suffix_sql(
472            fields,
473            on_conflict,
474            update_fields,
475            unique_fields,
476        )