Plain is headed towards 1.0! Subscribe for development updates →

  1from __future__ import annotations
  2
  3import datetime
  4import uuid
  5from collections.abc import Iterable
  6from functools import cached_property
  7from typing import TYPE_CHECKING, Any
  8
  9from plain.models.backends.base.operations import BaseDatabaseOperations
 10from plain.models.backends.utils import CursorWrapper, split_tzname_delta
 11from plain.models.constants import OnConflict
 12from plain.models.expressions import Exists, ExpressionWrapper, ResolvableExpression
 13from plain.models.lookups import Lookup
 14from plain.utils import timezone
 15from plain.utils.encoding import force_str
 16from plain.utils.regex_helper import _lazy_re_compile
 17
 18if TYPE_CHECKING:
 19    from plain.models.backends.base.base import BaseDatabaseWrapper
 20    from plain.models.backends.mysql.base import MySQLDatabaseWrapper
 21    from plain.models.fields import Field
 22    from plain.models.sql.compiler import SQLCompiler
 23    from plain.models.sql.query import Query
 24
 25
 26class DatabaseOperations(BaseDatabaseOperations):
 27    # Type checker hint: connection is always MySQLDatabaseWrapper in this class
 28    connection: MySQLDatabaseWrapper
 29
 30    @cached_property
 31    def compilers(self) -> dict[type[Query], type[SQLCompiler]]:
 32        from plain.models.backends.mysql import compiler as mysql_compiler
 33        from plain.models.sql.query import Query
 34        from plain.models.sql.subqueries import (
 35            AggregateQuery,
 36            DeleteQuery,
 37            InsertQuery,
 38            UpdateQuery,
 39        )
 40
 41        return {
 42            Query: mysql_compiler.SQLCompiler,
 43            DeleteQuery: mysql_compiler.SQLDeleteCompiler,
 44            UpdateQuery: mysql_compiler.SQLUpdateCompiler,
 45            InsertQuery: mysql_compiler.SQLInsertCompiler,
 46            AggregateQuery: mysql_compiler.SQLAggregateCompiler,
 47        }
 48
    # MySQL stores positive fields as UNSIGNED ints.
    # The inherited ranges are kept; the three Positive* entries below are set
    # to (lower, upper) bounds starting at 0.
    integer_field_ranges = {
        **BaseDatabaseOperations.integer_field_ranges,
        "PositiveSmallIntegerField": (0, 65535),  # 2**16 - 1
        "PositiveIntegerField": (0, 4294967295),  # 2**32 - 1
        "PositiveBigIntegerField": (0, 18446744073709551615),  # 2**64 - 1
    }
    # SQL type strings keyed by field internal type name. The %(...)s
    # placeholders are presumably filled in from the field's attributes by the
    # code that consumes this mapping -- TODO confirm against the caller.
    cast_data_types = {
        "PrimaryKeyField": "signed integer",
        "CharField": "char(%(max_length)s)",
        "DecimalField": "decimal(%(max_digits)s, %(decimal_places)s)",
        "TextField": "char",
        "IntegerField": "signed integer",
        "BigIntegerField": "signed integer",
        "SmallIntegerField": "signed integer",
        "PositiveBigIntegerField": "unsigned integer",
        "PositiveIntegerField": "unsigned integer",
        "PositiveSmallIntegerField": "unsigned integer",
        "DurationField": "signed integer",
    }
    # NOTE(review): looks like the cast type used for a CharField that has no
    # max_length -- confirm against the base class.
    cast_char_field_without_max_length = "char"
    explain_prefix = "EXPLAIN"

    # EXTRACT format cannot be passed in parameters.
    # date_extract_sql() therefore checks the upper-cased lookup type against
    # this pattern (uppercase letters and underscores only) with fullmatch()
    # before interpolating it into the SQL text.
    _extract_format_re = _lazy_re_compile(r"[A-Z_]+")
 74
 75    def date_extract_sql(
 76        self, lookup_type: str, sql: str, params: list[Any] | tuple[Any, ...]
 77    ) -> tuple[str, list[Any] | tuple[Any, ...]]:
 78        # https://dev.mysql.com/doc/mysql/en/date-and-time-functions.html
 79        if lookup_type == "week_day":
 80            # DAYOFWEEK() returns an integer, 1-7, Sunday=1.
 81            return f"DAYOFWEEK({sql})", params
 82        elif lookup_type == "iso_week_day":
 83            # WEEKDAY() returns an integer, 0-6, Monday=0.
 84            return f"WEEKDAY({sql}) + 1", params
 85        elif lookup_type == "week":
 86            # Override the value of default_week_format for consistency with
 87            # other database backends.
 88            # Mode 3: Monday, 1-53, with 4 or more days this year.
 89            return f"WEEK({sql}, 3)", params
 90        elif lookup_type == "iso_year":
 91            # Get the year part from the YEARWEEK function, which returns a
 92            # number as year * 100 + week.
 93            return f"TRUNCATE(YEARWEEK({sql}, 3), -2) / 100", params
 94        else:
 95            # EXTRACT returns 1-53 based on ISO-8601 for the week number.
 96            lookup_type = lookup_type.upper()
 97            if not self._extract_format_re.fullmatch(lookup_type):
 98                raise ValueError(f"Invalid loookup type: {lookup_type!r}")
 99            return f"EXTRACT({lookup_type} FROM {sql})", params
100
101    def date_trunc_sql(
102        self,
103        lookup_type: str,
104        sql: str,
105        params: list[Any] | tuple[Any, ...],
106        tzname: str | None = None,
107    ) -> tuple[str, list[Any] | tuple[Any, ...]]:
108        sql, params = self._convert_sql_to_tz(sql, params, tzname)
109        fields = {
110            "year": "%Y-01-01",
111            "month": "%Y-%m-01",
112        }
113        if lookup_type in fields:
114            format_str = fields[lookup_type]
115            return f"CAST(DATE_FORMAT({sql}, %s) AS DATE)", (*params, format_str)
116        elif lookup_type == "quarter":
117            return (
118                f"MAKEDATE(YEAR({sql}), 1) + "
119                f"INTERVAL QUARTER({sql}) QUARTER - INTERVAL 1 QUARTER",
120                (*params, *params),
121            )
122        elif lookup_type == "week":
123            return f"DATE_SUB({sql}, INTERVAL WEEKDAY({sql}) DAY)", (*params, *params)
124        else:
125            return f"DATE({sql})", params
126
127    def _prepare_tzname_delta(self, tzname: str) -> str:
128        tzname, sign, offset = split_tzname_delta(tzname)
129        return f"{sign}{offset}" if offset else tzname
130
131    def _convert_sql_to_tz(
132        self, sql: str, params: list[Any] | tuple[Any, ...], tzname: str | None
133    ) -> tuple[str, list[Any] | tuple[Any, ...]]:
134        if tzname and self.connection.timezone_name != tzname:
135            return f"CONVERT_TZ({sql}, %s, %s)", (
136                *params,
137                self.connection.timezone_name,
138                self._prepare_tzname_delta(tzname),
139            )
140        return sql, params
141
142    def datetime_cast_date_sql(
143        self, sql: str, params: list[Any] | tuple[Any, ...], tzname: str | None
144    ) -> tuple[str, list[Any] | tuple[Any, ...]]:
145        sql, params = self._convert_sql_to_tz(sql, params, tzname)
146        return f"DATE({sql})", params
147
148    def datetime_cast_time_sql(
149        self, sql: str, params: list[Any] | tuple[Any, ...], tzname: str | None
150    ) -> tuple[str, list[Any] | tuple[Any, ...]]:
151        sql, params = self._convert_sql_to_tz(sql, params, tzname)
152        return f"TIME({sql})", params
153
154    def datetime_extract_sql(
155        self,
156        lookup_type: str,
157        sql: str,
158        params: list[Any] | tuple[Any, ...],
159        tzname: str | None,
160    ) -> tuple[str, list[Any] | tuple[Any, ...]]:
161        sql, params = self._convert_sql_to_tz(sql, params, tzname)
162        return self.date_extract_sql(lookup_type, sql, params)
163
164    def datetime_trunc_sql(
165        self,
166        lookup_type: str,
167        sql: str,
168        params: list[Any] | tuple[Any, ...],
169        tzname: str | None,
170    ) -> tuple[str, list[Any] | tuple[Any, ...]]:
171        sql, params = self._convert_sql_to_tz(sql, params, tzname)
172        fields = ["year", "month", "day", "hour", "minute", "second"]
173        format = ("%Y-", "%m", "-%d", " %H:", "%i", ":%s")
174        format_def = ("0000-", "01", "-01", " 00:", "00", ":00")
175        if lookup_type == "quarter":
176            return (
177                f"CAST(DATE_FORMAT(MAKEDATE(YEAR({sql}), 1) + "
178                f"INTERVAL QUARTER({sql}) QUARTER - "
179                f"INTERVAL 1 QUARTER, %s) AS DATETIME)"
180            ), (*params, *params, "%Y-%m-01 00:00:00")
181        if lookup_type == "week":
182            return (
183                f"CAST(DATE_FORMAT("
184                f"DATE_SUB({sql}, INTERVAL WEEKDAY({sql}) DAY), %s) AS DATETIME)"
185            ), (*params, *params, "%Y-%m-%d 00:00:00")
186        try:
187            i = fields.index(lookup_type) + 1
188        except ValueError:
189            pass
190        else:
191            format_str = "".join(format[:i] + format_def[i:])
192            return f"CAST(DATE_FORMAT({sql}, %s) AS DATETIME)", (*params, format_str)
193        return sql, params
194
195    def time_trunc_sql(
196        self,
197        lookup_type: str,
198        sql: str,
199        params: list[Any] | tuple[Any, ...],
200        tzname: str | None = None,
201    ) -> tuple[str, list[Any] | tuple[Any, ...]]:
202        sql, params = self._convert_sql_to_tz(sql, params, tzname)
203        fields = {
204            "hour": "%H:00:00",
205            "minute": "%H:%i:00",
206            "second": "%H:%i:%s",
207        }
208        if lookup_type in fields:
209            format_str = fields[lookup_type]
210            return f"CAST(DATE_FORMAT({sql}, %s) AS TIME)", (*params, format_str)
211        else:
212            return f"TIME({sql})", params
213
214    def fetch_returned_insert_rows(self, cursor: CursorWrapper) -> list[Any]:
215        """
216        Given a cursor object that has just performed an INSERT...RETURNING
217        statement into a table, return the tuple of returned data.
218        """
219        return cursor.fetchall()
220
221    def format_for_duration_arithmetic(self, sql: str) -> str:
222        return f"INTERVAL {sql} MICROSECOND"
223
224    def force_no_ordering(
225        self,
226    ) -> list[tuple[Any, tuple[str, tuple[Any, ...], bool]]]:
227        """
228        "ORDER BY NULL" prevents MySQL from implicitly ordering by grouped
229        columns. If no ordering would otherwise be applied, we don't want any
230        implicit sorting going on.
231        """
232        return [(None, ("NULL", (), False))]
233
234    def adapt_decimalfield_value(
235        self,
236        value: Any,
237        max_digits: int | None = None,
238        decimal_places: int | None = None,
239    ) -> Any:
240        return value
241
242    def last_executed_query(
243        self, cursor: CursorWrapper, sql: str, params: Any
244    ) -> str | None:
245        # With MySQLdb, cursor objects have an (undocumented) "_executed"
246        # attribute where the exact query sent to the database is saved.
247        # See MySQLdb/cursors.py in the source distribution.
248        # MySQLdb returns string, PyMySQL bytes.
249        return force_str(getattr(cursor, "_executed", None), errors="replace")
250
251    def no_limit_value(self) -> int:
252        # 2**64 - 1, as recommended by the MySQL documentation
253        return 18446744073709551615
254
255    def quote_name(self, name: str) -> str:
256        if name.startswith("`") and name.endswith("`"):
257            return name  # Quoting once is enough.
258        return f"`{name}`"
259
260    def return_insert_columns(self, fields: list[Any]) -> tuple[str, tuple[Any, ...]]:
261        # MySQL and MariaDB < 10.5.0 don't support an INSERT...RETURNING
262        # statement.
263        if not fields:
264            return "", ()
265        columns = [
266            f"{self.quote_name(field.model.model_options.db_table)}.{self.quote_name(field.column)}"
267            for field in fields
268        ]
269        return "RETURNING {}".format(", ".join(columns)), ()
270
271    def validate_autopk_value(self, value: int) -> int:
272        # Zero in AUTO_INCREMENT field does not work without the
273        # NO_AUTO_VALUE_ON_ZERO SQL mode.
274        if value == 0 and not self.connection.features.allows_auto_pk_0:
275            raise ValueError(
276                "The database backend does not accept 0 as a value for PrimaryKeyField."
277            )
278        return value
279
280    def adapt_datetimefield_value(
281        self, value: datetime.datetime | Any | None
282    ) -> str | Any | None:
283        if value is None:
284            return None
285
286        # Expression values are adapted by the database.
287        if isinstance(value, ResolvableExpression):
288            return value
289
290        # MySQL doesn't support tz-aware datetimes
291        if timezone.is_aware(value):
292            value = timezone.make_naive(value, self.connection.timezone)
293        return str(value)
294
295    def adapt_timefield_value(
296        self, value: datetime.time | Any | None
297    ) -> str | Any | None:
298        if value is None:
299            return None
300
301        # Expression values are adapted by the database.
302        if isinstance(value, ResolvableExpression):
303            return value
304
305        # MySQL doesn't support tz-aware times
306        if timezone.is_aware(value):
307            raise ValueError("MySQL backend does not support timezone-aware times.")
308
309        return value.isoformat(timespec="microseconds")
310
311    def max_name_length(self) -> int:
312        return 64
313
314    def pk_default_value(self) -> str:
315        return "NULL"
316
317    def bulk_insert_sql(
318        self, fields: list[Any], placeholder_rows: list[list[str]]
319    ) -> str:
320        placeholder_rows_sql = (", ".join(row) for row in placeholder_rows)
321        values_sql = ", ".join(f"({sql})" for sql in placeholder_rows_sql)
322        return "VALUES " + values_sql
323
324    def combine_expression(self, connector: str, sub_expressions: list[str]) -> str:
325        if connector == "^":
326            return "POW({})".format(",".join(sub_expressions))
327        # Convert the result to a signed integer since MySQL's binary operators
328        # return an unsigned integer.
329        elif connector in ("&", "|", "<<", "#"):
330            connector = "^" if connector == "#" else connector
331            return f"CONVERT({connector.join(sub_expressions)}, SIGNED)"
332        elif connector == ">>":
333            lhs, rhs = sub_expressions
334            return f"FLOOR({lhs} / POW(2, {rhs}))"
335        return super().combine_expression(connector, sub_expressions)
336
337    def get_db_converters(self, expression: Any) -> list[Any]:
338        converters = super().get_db_converters(expression)
339        internal_type = expression.output_field.get_internal_type()
340        if internal_type == "BooleanField":
341            converters.append(self.convert_booleanfield_value)
342        elif internal_type == "DateTimeField":
343            converters.append(self.convert_datetimefield_value)
344        elif internal_type == "UUIDField":
345            converters.append(self.convert_uuidfield_value)
346        return converters
347
348    def convert_booleanfield_value(
349        self, value: Any, expression: Any, connection: BaseDatabaseWrapper
350    ) -> Any:
351        if value in (0, 1):
352            value = bool(value)
353        return value
354
355    def convert_datetimefield_value(
356        self, value: Any, expression: Any, connection: BaseDatabaseWrapper
357    ) -> datetime.datetime | None:
358        if value is not None:
359            value = timezone.make_aware(value, self.connection.timezone)
360        return value
361
362    def convert_uuidfield_value(
363        self, value: Any, expression: Any, connection: BaseDatabaseWrapper
364    ) -> uuid.UUID | None:
365        if value is not None:
366            value = uuid.UUID(value)
367        return value
368
369    def binary_placeholder_sql(self, value: Any) -> str:
370        return (
371            "_binary %s" if value is not None and not hasattr(value, "as_sql") else "%s"
372        )
373
374    def subtract_temporals(
375        self,
376        internal_type: str,
377        lhs: tuple[str, list[Any] | tuple[Any, ...]],
378        rhs: tuple[str, list[Any] | tuple[Any, ...]],
379    ) -> tuple[str, tuple[Any, ...]]:
380        lhs_sql, lhs_params = lhs
381        rhs_sql, rhs_params = rhs
382        if internal_type == "TimeField":
383            if self.connection.mysql_is_mariadb:
384                # MariaDB includes the microsecond component in TIME_TO_SEC as
385                # a decimal. MySQL returns an integer without microseconds.
386                return (
387                    f"CAST((TIME_TO_SEC({lhs_sql}) - TIME_TO_SEC({rhs_sql})) "
388                    "* 1000000 AS SIGNED)"
389                ), (
390                    *lhs_params,
391                    *rhs_params,
392                )
393            return (
394                f"((TIME_TO_SEC({lhs_sql}) * 1000000 + MICROSECOND({lhs_sql})) -"
395                f" (TIME_TO_SEC({rhs_sql}) * 1000000 + MICROSECOND({rhs_sql})))"
396            ), tuple(lhs_params) * 2 + tuple(rhs_params) * 2
397        params = (*rhs_params, *lhs_params)
398        return f"TIMESTAMPDIFF(MICROSECOND, {rhs_sql}, {lhs_sql})", params
399
    def explain_query_prefix(self, format: str | None = None, **options: Any) -> str:
        """Build the EXPLAIN/ANALYZE prefix for a query.

        ``format`` "TEXT" (any case) is rewritten to "TRADITIONAL". When no
        format is given and the connection lists "TREE" among its supported
        explain formats, "TREE" is used. The ``analyze`` option is removed
        from ``options`` before the base implementation is called with the
        remaining options.

        Returns the prefix string, with ``FORMAT=<format>`` appended unless
        analyze was requested on a non-MariaDB connection.
        """
        # Alias MySQL's TRADITIONAL to TEXT for consistency with other backends.
        if format and format.upper() == "TEXT":
            format = "TRADITIONAL"
        elif (
            not format and "TREE" in self.connection.features.supported_explain_formats
        ):
            # Use TREE by default (if supported) as it's more informative.
            format = "TREE"
        # Pop "analyze" so it is not forwarded to the base implementation.
        analyze = options.pop("analyze", False)
        prefix = super().explain_query_prefix(format, **options)
        if analyze and self.connection.features.supports_explain_analyze:
            # MariaDB uses ANALYZE instead of EXPLAIN ANALYZE.
            prefix = (
                "ANALYZE" if self.connection.mysql_is_mariadb else prefix + " ANALYZE"
            )
        # Note: this check uses the requested ``analyze`` flag only; it does
        # not re-check supports_explain_analyze.
        if format and not (analyze and not self.connection.mysql_is_mariadb):
            # Only MariaDB supports the analyze option with formats.
            prefix += f" FORMAT={format}"
        return prefix
420
421    def regex_lookup(self, lookup_type: str) -> str:
422        # REGEXP_LIKE doesn't exist in MariaDB.
423        if self.connection.mysql_is_mariadb:
424            if lookup_type == "regex":
425                return "%s REGEXP BINARY %s"
426            return "%s REGEXP %s"
427
428        match_option = "c" if lookup_type == "regex" else "i"
429        return f"REGEXP_LIKE(%s, %s, '{match_option}')"
430
431    def insert_statement(self, on_conflict: Any = None) -> str:
432        if on_conflict == OnConflict.IGNORE:
433            return "INSERT IGNORE INTO"
434        return super().insert_statement(on_conflict=on_conflict)
435
436    def lookup_cast(self, lookup_type: str, internal_type: str | None = None) -> str:
437        lookup = "%s"
438        if internal_type == "JSONField":
439            if self.connection.mysql_is_mariadb or lookup_type in (
440                "iexact",
441                "contains",
442                "icontains",
443                "startswith",
444                "istartswith",
445                "endswith",
446                "iendswith",
447                "regex",
448                "iregex",
449            ):
450                lookup = "JSON_UNQUOTE(%s)"
451        return lookup
452
453    def conditional_expression_supported_in_where_clause(self, expression: Any) -> bool:
454        # MySQL ignores indexes with boolean fields unless they're compared
455        # directly to a boolean value.
456        if isinstance(expression, Exists | Lookup):
457            return True
458        if isinstance(expression, ExpressionWrapper) and expression.conditional:
459            return self.conditional_expression_supported_in_where_clause(
460                expression.expression
461            )
462        if getattr(expression, "conditional", False):
463            return False
464        return super().conditional_expression_supported_in_where_clause(expression)
465
466    def on_conflict_suffix_sql(
467        self,
468        fields: list[Field],
469        on_conflict: Any,
470        update_fields: Iterable[str],
471        unique_fields: Iterable[str],
472    ) -> str:
473        if on_conflict == OnConflict.UPDATE:
474            conflict_suffix_sql = "ON DUPLICATE KEY UPDATE %(fields)s"
475            # The use of VALUES() is deprecated in MySQL 8.0.20+. Instead, use
476            # aliases for the new row and its columns available in MySQL
477            # 8.0.19+.
478            if not self.connection.mysql_is_mariadb:
479                if self.connection.mysql_version >= (8, 0, 19):
480                    conflict_suffix_sql = f"AS new {conflict_suffix_sql}"
481                    field_sql = "%(field)s = new.%(field)s"
482                else:
483                    field_sql = "%(field)s = VALUES(%(field)s)"
484            # Use VALUE() on MariaDB.
485            else:
486                field_sql = "%(field)s = VALUE(%(field)s)"
487
488            fields_str = ", ".join(
489                [
490                    field_sql % {"field": field}
491                    for field in map(self.quote_name, update_fields)
492                ]
493            )
494            return conflict_suffix_sql % {"fields": fields_str}
495        return super().on_conflict_suffix_sql(
496            fields,
497            on_conflict,
498            update_fields,
499            unique_fields,
500        )