Plain is headed towards 1.0! Subscribe for development updates →

  1from __future__ import annotations
  2
  3import datetime
  4import uuid
  5from typing import TYPE_CHECKING, Any
  6
  7from plain.models.backends.base.operations import BaseDatabaseOperations
  8from plain.models.backends.utils import split_tzname_delta
  9from plain.models.constants import OnConflict
 10from plain.models.expressions import Exists, ExpressionWrapper
 11from plain.models.lookups import Lookup
 12from plain.utils import timezone
 13from plain.utils.encoding import force_str
 14from plain.utils.regex_helper import _lazy_re_compile
 15
 16if TYPE_CHECKING:
 17    from plain.models.backends.base.base import BaseDatabaseWrapper
 18
 19
 20class DatabaseOperations(BaseDatabaseOperations):
    # Dotted module path naming the SQL compiler module for this backend.
    compiler_module = "plain.models.backends.mysql.compiler"

    # MySQL stores positive fields as UNSIGNED ints.
    # Each entry maps a field's internal type name to a (minimum, maximum)
    # pair; the base class ranges are kept and the positive types overridden.
    integer_field_ranges = {
        **BaseDatabaseOperations.integer_field_ranges,
        "PositiveSmallIntegerField": (0, 65535),
        "PositiveIntegerField": (0, 4294967295),
        "PositiveBigIntegerField": (0, 18446744073709551615),
    }
    # Maps a field's internal type name to a SQL type string; the %(name)s
    # placeholders are named after field attributes (max_length, max_digits,
    # decimal_places).
    # NOTE(review): presumably consumed when building CAST expressions --
    # confirm against the base class.
    cast_data_types = {
        "PrimaryKeyField": "signed integer",
        "CharField": "char(%(max_length)s)",
        "DecimalField": "decimal(%(max_digits)s, %(decimal_places)s)",
        "TextField": "char",
        "IntegerField": "signed integer",
        "BigIntegerField": "signed integer",
        "SmallIntegerField": "signed integer",
        "PositiveBigIntegerField": "unsigned integer",
        "PositiveIntegerField": "unsigned integer",
        "PositiveSmallIntegerField": "unsigned integer",
        "DurationField": "signed integer",
    }
    # SQL type string for a CharField that has no max_length.
    cast_char_field_without_max_length = "char"
    # Keyword that explain_query_prefix() builds on (via the base class).
    explain_prefix = "EXPLAIN"

    # EXTRACT format cannot be passed in parameters.
    # date_extract_sql() interpolates the upper-cased lookup type directly
    # into the SQL text, so it is first checked against this pattern
    # (uppercase ASCII letters and underscores only).
    _extract_format_re = _lazy_re_compile(r"[A-Z_]+")
 48
 49    def date_extract_sql(
 50        self, lookup_type: str, sql: str, params: list[Any] | tuple[Any, ...]
 51    ) -> tuple[str, list[Any] | tuple[Any, ...]]:
 52        # https://dev.mysql.com/doc/mysql/en/date-and-time-functions.html
 53        if lookup_type == "week_day":
 54            # DAYOFWEEK() returns an integer, 1-7, Sunday=1.
 55            return f"DAYOFWEEK({sql})", params
 56        elif lookup_type == "iso_week_day":
 57            # WEEKDAY() returns an integer, 0-6, Monday=0.
 58            return f"WEEKDAY({sql}) + 1", params
 59        elif lookup_type == "week":
 60            # Override the value of default_week_format for consistency with
 61            # other database backends.
 62            # Mode 3: Monday, 1-53, with 4 or more days this year.
 63            return f"WEEK({sql}, 3)", params
 64        elif lookup_type == "iso_year":
 65            # Get the year part from the YEARWEEK function, which returns a
 66            # number as year * 100 + week.
 67            return f"TRUNCATE(YEARWEEK({sql}, 3), -2) / 100", params
 68        else:
 69            # EXTRACT returns 1-53 based on ISO-8601 for the week number.
 70            lookup_type = lookup_type.upper()
 71            if not self._extract_format_re.fullmatch(lookup_type):
 72                raise ValueError(f"Invalid loookup type: {lookup_type!r}")
 73            return f"EXTRACT({lookup_type} FROM {sql})", params
 74
 75    def date_trunc_sql(
 76        self,
 77        lookup_type: str,
 78        sql: str,
 79        params: list[Any] | tuple[Any, ...],
 80        tzname: str | None = None,
 81    ) -> tuple[str, list[Any] | tuple[Any, ...]]:
 82        sql, params = self._convert_sql_to_tz(sql, params, tzname)
 83        fields = {
 84            "year": "%Y-01-01",
 85            "month": "%Y-%m-01",
 86        }
 87        if lookup_type in fields:
 88            format_str = fields[lookup_type]
 89            return f"CAST(DATE_FORMAT({sql}, %s) AS DATE)", (*params, format_str)
 90        elif lookup_type == "quarter":
 91            return (
 92                f"MAKEDATE(YEAR({sql}), 1) + "
 93                f"INTERVAL QUARTER({sql}) QUARTER - INTERVAL 1 QUARTER",
 94                (*params, *params),
 95            )
 96        elif lookup_type == "week":
 97            return f"DATE_SUB({sql}, INTERVAL WEEKDAY({sql}) DAY)", (*params, *params)
 98        else:
 99            return f"DATE({sql})", params
100
101    def _prepare_tzname_delta(self, tzname: str) -> str:
102        tzname, sign, offset = split_tzname_delta(tzname)
103        return f"{sign}{offset}" if offset else tzname
104
105    def _convert_sql_to_tz(
106        self, sql: str, params: list[Any] | tuple[Any, ...], tzname: str | None
107    ) -> tuple[str, list[Any] | tuple[Any, ...]]:
108        if tzname and self.connection.timezone_name != tzname:
109            return f"CONVERT_TZ({sql}, %s, %s)", (
110                *params,
111                self.connection.timezone_name,
112                self._prepare_tzname_delta(tzname),
113            )
114        return sql, params
115
116    def datetime_cast_date_sql(
117        self, sql: str, params: list[Any] | tuple[Any, ...], tzname: str | None
118    ) -> tuple[str, list[Any] | tuple[Any, ...]]:
119        sql, params = self._convert_sql_to_tz(sql, params, tzname)
120        return f"DATE({sql})", params
121
122    def datetime_cast_time_sql(
123        self, sql: str, params: list[Any] | tuple[Any, ...], tzname: str | None
124    ) -> tuple[str, list[Any] | tuple[Any, ...]]:
125        sql, params = self._convert_sql_to_tz(sql, params, tzname)
126        return f"TIME({sql})", params
127
128    def datetime_extract_sql(
129        self,
130        lookup_type: str,
131        sql: str,
132        params: list[Any] | tuple[Any, ...],
133        tzname: str | None,
134    ) -> tuple[str, list[Any] | tuple[Any, ...]]:
135        sql, params = self._convert_sql_to_tz(sql, params, tzname)
136        return self.date_extract_sql(lookup_type, sql, params)
137
138    def datetime_trunc_sql(
139        self,
140        lookup_type: str,
141        sql: str,
142        params: list[Any] | tuple[Any, ...],
143        tzname: str | None,
144    ) -> tuple[str, list[Any] | tuple[Any, ...]]:
145        sql, params = self._convert_sql_to_tz(sql, params, tzname)
146        fields = ["year", "month", "day", "hour", "minute", "second"]
147        format = ("%Y-", "%m", "-%d", " %H:", "%i", ":%s")
148        format_def = ("0000-", "01", "-01", " 00:", "00", ":00")
149        if lookup_type == "quarter":
150            return (
151                f"CAST(DATE_FORMAT(MAKEDATE(YEAR({sql}), 1) + "
152                f"INTERVAL QUARTER({sql}) QUARTER - "
153                f"INTERVAL 1 QUARTER, %s) AS DATETIME)"
154            ), (*params, *params, "%Y-%m-01 00:00:00")
155        if lookup_type == "week":
156            return (
157                f"CAST(DATE_FORMAT("
158                f"DATE_SUB({sql}, INTERVAL WEEKDAY({sql}) DAY), %s) AS DATETIME)"
159            ), (*params, *params, "%Y-%m-%d 00:00:00")
160        try:
161            i = fields.index(lookup_type) + 1
162        except ValueError:
163            pass
164        else:
165            format_str = "".join(format[:i] + format_def[i:])
166            return f"CAST(DATE_FORMAT({sql}, %s) AS DATETIME)", (*params, format_str)
167        return sql, params
168
169    def time_trunc_sql(
170        self,
171        lookup_type: str,
172        sql: str,
173        params: list[Any] | tuple[Any, ...],
174        tzname: str | None = None,
175    ) -> tuple[str, list[Any] | tuple[Any, ...]]:
176        sql, params = self._convert_sql_to_tz(sql, params, tzname)
177        fields = {
178            "hour": "%H:00:00",
179            "minute": "%H:%i:00",
180            "second": "%H:%i:%s",
181        }
182        if lookup_type in fields:
183            format_str = fields[lookup_type]
184            return f"CAST(DATE_FORMAT({sql}, %s) AS TIME)", (*params, format_str)
185        else:
186            return f"TIME({sql})", params
187
188    def fetch_returned_insert_rows(self, cursor: Any) -> list[Any]:
189        """
190        Given a cursor object that has just performed an INSERT...RETURNING
191        statement into a table, return the tuple of returned data.
192        """
193        return cursor.fetchall()
194
195    def format_for_duration_arithmetic(self, sql: str) -> str:
196        return f"INTERVAL {sql} MICROSECOND"
197
198    def force_no_ordering(self) -> list[tuple[None, tuple[str, list[Any], bool]]]:
199        """
200        "ORDER BY NULL" prevents MySQL from implicitly ordering by grouped
201        columns. If no ordering would otherwise be applied, we don't want any
202        implicit sorting going on.
203        """
204        return [(None, ("NULL", [], False))]
205
206    def adapt_decimalfield_value(
207        self,
208        value: Any,
209        max_digits: int | None = None,
210        decimal_places: int | None = None,
211    ) -> Any:
212        return value
213
214    def last_executed_query(self, cursor: Any, sql: str, params: Any) -> str | None:
215        # With MySQLdb, cursor objects have an (undocumented) "_executed"
216        # attribute where the exact query sent to the database is saved.
217        # See MySQLdb/cursors.py in the source distribution.
218        # MySQLdb returns string, PyMySQL bytes.
219        return force_str(getattr(cursor, "_executed", None), errors="replace")
220
221    def no_limit_value(self) -> int:
222        # 2**64 - 1, as recommended by the MySQL documentation
223        return 18446744073709551615
224
225    def quote_name(self, name: str) -> str:
226        if name.startswith("`") and name.endswith("`"):
227            return name  # Quoting once is enough.
228        return f"`{name}`"
229
230    def return_insert_columns(self, fields: list[Any]) -> tuple[str, tuple[Any, ...]]:
231        # MySQL and MariaDB < 10.5.0 don't support an INSERT...RETURNING
232        # statement.
233        if not fields:
234            return "", ()
235        columns = [
236            f"{self.quote_name(field.model.model_options.db_table)}.{self.quote_name(field.column)}"
237            for field in fields
238        ]
239        return "RETURNING {}".format(", ".join(columns)), ()
240
241    def validate_autopk_value(self, value: int) -> int:
242        # Zero in AUTO_INCREMENT field does not work without the
243        # NO_AUTO_VALUE_ON_ZERO SQL mode.
244        if value == 0 and not self.connection.features.allows_auto_pk_0:
245            raise ValueError(
246                "The database backend does not accept 0 as a value for PrimaryKeyField."
247            )
248        return value
249
250    def adapt_datetimefield_value(
251        self, value: datetime.datetime | Any | None
252    ) -> str | Any | None:
253        if value is None:
254            return None
255
256        # Expression values are adapted by the database.
257        if hasattr(value, "resolve_expression"):
258            return value
259
260        # MySQL doesn't support tz-aware datetimes
261        if timezone.is_aware(value):
262            value = timezone.make_naive(value, self.connection.timezone)
263        return str(value)
264
265    def adapt_timefield_value(
266        self, value: datetime.time | Any | None
267    ) -> str | Any | None:
268        if value is None:
269            return None
270
271        # Expression values are adapted by the database.
272        if hasattr(value, "resolve_expression"):
273            return value
274
275        # MySQL doesn't support tz-aware times
276        if timezone.is_aware(value):  # type: ignore[arg-type]
277            raise ValueError("MySQL backend does not support timezone-aware times.")
278
279        return value.isoformat(timespec="microseconds")
280
281    def max_name_length(self) -> int:
282        return 64
283
284    def pk_default_value(self) -> str:
285        return "NULL"
286
287    def bulk_insert_sql(
288        self, fields: list[Any], placeholder_rows: list[list[str]]
289    ) -> str:
290        placeholder_rows_sql = (", ".join(row) for row in placeholder_rows)
291        values_sql = ", ".join(f"({sql})" for sql in placeholder_rows_sql)
292        return "VALUES " + values_sql
293
294    def combine_expression(self, connector: str, sub_expressions: list[str]) -> str:
295        if connector == "^":
296            return "POW({})".format(",".join(sub_expressions))
297        # Convert the result to a signed integer since MySQL's binary operators
298        # return an unsigned integer.
299        elif connector in ("&", "|", "<<", "#"):
300            connector = "^" if connector == "#" else connector
301            return f"CONVERT({connector.join(sub_expressions)}, SIGNED)"
302        elif connector == ">>":
303            lhs, rhs = sub_expressions
304            return f"FLOOR({lhs} / POW(2, {rhs}))"
305        return super().combine_expression(connector, sub_expressions)
306
307    def get_db_converters(self, expression: Any) -> list[Any]:
308        converters = super().get_db_converters(expression)
309        internal_type = expression.output_field.get_internal_type()
310        if internal_type == "BooleanField":
311            converters.append(self.convert_booleanfield_value)
312        elif internal_type == "DateTimeField":
313            converters.append(self.convert_datetimefield_value)
314        elif internal_type == "UUIDField":
315            converters.append(self.convert_uuidfield_value)
316        return converters
317
318    def convert_booleanfield_value(
319        self, value: Any, expression: Any, connection: BaseDatabaseWrapper
320    ) -> Any:
321        if value in (0, 1):
322            value = bool(value)
323        return value
324
325    def convert_datetimefield_value(
326        self, value: Any, expression: Any, connection: BaseDatabaseWrapper
327    ) -> datetime.datetime | None:
328        if value is not None:
329            value = timezone.make_aware(value, self.connection.timezone)
330        return value
331
332    def convert_uuidfield_value(
333        self, value: Any, expression: Any, connection: BaseDatabaseWrapper
334    ) -> uuid.UUID | None:
335        if value is not None:
336            value = uuid.UUID(value)
337        return value
338
339    def binary_placeholder_sql(self, value: Any) -> str:
340        return (
341            "_binary %s" if value is not None and not hasattr(value, "as_sql") else "%s"
342        )
343
344    def subtract_temporals(
345        self,
346        internal_type: str,
347        lhs: tuple[str, list[Any] | tuple[Any, ...]],
348        rhs: tuple[str, list[Any] | tuple[Any, ...]],
349    ) -> tuple[str, tuple[Any, ...]]:
350        lhs_sql, lhs_params = lhs
351        rhs_sql, rhs_params = rhs
352        if internal_type == "TimeField":
353            if self.connection.mysql_is_mariadb:
354                # MariaDB includes the microsecond component in TIME_TO_SEC as
355                # a decimal. MySQL returns an integer without microseconds.
356                return (
357                    f"CAST((TIME_TO_SEC({lhs_sql}) - TIME_TO_SEC({rhs_sql})) "
358                    "* 1000000 AS SIGNED)"
359                ), (
360                    *lhs_params,
361                    *rhs_params,
362                )
363            return (
364                f"((TIME_TO_SEC({lhs_sql}) * 1000000 + MICROSECOND({lhs_sql})) -"
365                f" (TIME_TO_SEC({rhs_sql}) * 1000000 + MICROSECOND({rhs_sql})))"
366            ), tuple(lhs_params) * 2 + tuple(rhs_params) * 2
367        params = (*rhs_params, *lhs_params)
368        return f"TIMESTAMPDIFF(MICROSECOND, {rhs_sql}, {lhs_sql})", params
369
370    def explain_query_prefix(self, format: str | None = None, **options: Any) -> str:
371        # Alias MySQL's TRADITIONAL to TEXT for consistency with other backends.
372        if format and format.upper() == "TEXT":
373            format = "TRADITIONAL"
374        elif (
375            not format and "TREE" in self.connection.features.supported_explain_formats
376        ):
377            # Use TREE by default (if supported) as it's more informative.
378            format = "TREE"
379        analyze = options.pop("analyze", False)
380        prefix = super().explain_query_prefix(format, **options)
381        if analyze and self.connection.features.supports_explain_analyze:
382            # MariaDB uses ANALYZE instead of EXPLAIN ANALYZE.
383            prefix = (
384                "ANALYZE" if self.connection.mysql_is_mariadb else prefix + " ANALYZE"
385            )
386        if format and not (analyze and not self.connection.mysql_is_mariadb):
387            # Only MariaDB supports the analyze option with formats.
388            prefix += f" FORMAT={format}"
389        return prefix
390
391    def regex_lookup(self, lookup_type: str) -> str:
392        # REGEXP_LIKE doesn't exist in MariaDB.
393        if self.connection.mysql_is_mariadb:
394            if lookup_type == "regex":
395                return "%s REGEXP BINARY %s"
396            return "%s REGEXP %s"
397
398        match_option = "c" if lookup_type == "regex" else "i"
399        return f"REGEXP_LIKE(%s, %s, '{match_option}')"
400
401    def insert_statement(self, on_conflict: Any = None) -> str:
402        if on_conflict == OnConflict.IGNORE:
403            return "INSERT IGNORE INTO"
404        return super().insert_statement(on_conflict=on_conflict)
405
406    def lookup_cast(self, lookup_type: str, internal_type: str | None = None) -> str:
407        lookup = "%s"
408        if internal_type == "JSONField":
409            if self.connection.mysql_is_mariadb or lookup_type in (
410                "iexact",
411                "contains",
412                "icontains",
413                "startswith",
414                "istartswith",
415                "endswith",
416                "iendswith",
417                "regex",
418                "iregex",
419            ):
420                lookup = "JSON_UNQUOTE(%s)"
421        return lookup
422
423    def conditional_expression_supported_in_where_clause(self, expression: Any) -> bool:
424        # MySQL ignores indexes with boolean fields unless they're compared
425        # directly to a boolean value.
426        if isinstance(expression, Exists | Lookup):
427            return True
428        if isinstance(expression, ExpressionWrapper) and expression.conditional:
429            return self.conditional_expression_supported_in_where_clause(
430                expression.expression
431            )
432        if getattr(expression, "conditional", False):
433            return False
434        return super().conditional_expression_supported_in_where_clause(expression)
435
436    def on_conflict_suffix_sql(
437        self,
438        fields: list[Any],
439        on_conflict: Any,
440        update_fields: list[Any],
441        unique_fields: list[Any],
442    ) -> str:
443        if on_conflict == OnConflict.UPDATE:
444            conflict_suffix_sql = "ON DUPLICATE KEY UPDATE %(fields)s"
445            # The use of VALUES() is deprecated in MySQL 8.0.20+. Instead, use
446            # aliases for the new row and its columns available in MySQL
447            # 8.0.19+.
448            if not self.connection.mysql_is_mariadb:
449                if self.connection.mysql_version >= (8, 0, 19):
450                    conflict_suffix_sql = f"AS new {conflict_suffix_sql}"
451                    field_sql = "%(field)s = new.%(field)s"
452                else:
453                    field_sql = "%(field)s = VALUES(%(field)s)"
454            # Use VALUE() on MariaDB.
455            else:
456                field_sql = "%(field)s = VALUE(%(field)s)"
457
458            fields_str = ", ".join(
459                [
460                    field_sql % {"field": field}
461                    for field in map(self.quote_name, update_fields)
462                ]
463            )
464            return conflict_suffix_sql % {"fields": fields_str}
465        return super().on_conflict_suffix_sql(
466            fields,
467            on_conflict,
468            update_fields,
469            unique_fields,
470        )