Plain is headed towards 1.0! Subscribe for development updates →

  1import uuid
  2
  3from plain.models.backends.base.operations import BaseDatabaseOperations
  4from plain.models.backends.utils import split_tzname_delta
  5from plain.models.constants import OnConflict
  6from plain.models.expressions import Exists, ExpressionWrapper
  7from plain.models.lookups import Lookup
  8from plain.utils import timezone
  9from plain.utils.encoding import force_str
 10from plain.utils.regex_helper import _lazy_re_compile
 11
 12
 13class DatabaseOperations(BaseDatabaseOperations):
 14    compiler_module = "plain.models.backends.mysql.compiler"
 15
 16    # MySQL stores positive fields as UNSIGNED ints.
 17    integer_field_ranges = {
 18        **BaseDatabaseOperations.integer_field_ranges,
 19        "PositiveSmallIntegerField": (0, 65535),
 20        "PositiveIntegerField": (0, 4294967295),
 21        "PositiveBigIntegerField": (0, 18446744073709551615),
 22    }
 23    cast_data_types = {
 24        "AutoField": "signed integer",
 25        "BigAutoField": "signed integer",
 26        "SmallAutoField": "signed integer",
 27        "CharField": "char(%(max_length)s)",
 28        "DecimalField": "decimal(%(max_digits)s, %(decimal_places)s)",
 29        "TextField": "char",
 30        "IntegerField": "signed integer",
 31        "BigIntegerField": "signed integer",
 32        "SmallIntegerField": "signed integer",
 33        "PositiveBigIntegerField": "unsigned integer",
 34        "PositiveIntegerField": "unsigned integer",
 35        "PositiveSmallIntegerField": "unsigned integer",
 36        "DurationField": "signed integer",
 37    }
 38    cast_char_field_without_max_length = "char"
 39    explain_prefix = "EXPLAIN"
 40
 41    # EXTRACT format cannot be passed in parameters.
 42    _extract_format_re = _lazy_re_compile(r"[A-Z_]+")
 43
 44    def date_extract_sql(self, lookup_type, sql, params):
 45        # https://dev.mysql.com/doc/mysql/en/date-and-time-functions.html
 46        if lookup_type == "week_day":
 47            # DAYOFWEEK() returns an integer, 1-7, Sunday=1.
 48            return f"DAYOFWEEK({sql})", params
 49        elif lookup_type == "iso_week_day":
 50            # WEEKDAY() returns an integer, 0-6, Monday=0.
 51            return f"WEEKDAY({sql}) + 1", params
 52        elif lookup_type == "week":
 53            # Override the value of default_week_format for consistency with
 54            # other database backends.
 55            # Mode 3: Monday, 1-53, with 4 or more days this year.
 56            return f"WEEK({sql}, 3)", params
 57        elif lookup_type == "iso_year":
 58            # Get the year part from the YEARWEEK function, which returns a
 59            # number as year * 100 + week.
 60            return f"TRUNCATE(YEARWEEK({sql}, 3), -2) / 100", params
 61        else:
 62            # EXTRACT returns 1-53 based on ISO-8601 for the week number.
 63            lookup_type = lookup_type.upper()
 64            if not self._extract_format_re.fullmatch(lookup_type):
 65                raise ValueError(f"Invalid loookup type: {lookup_type!r}")
 66            return f"EXTRACT({lookup_type} FROM {sql})", params
 67
 68    def date_trunc_sql(self, lookup_type, sql, params, tzname=None):
 69        sql, params = self._convert_sql_to_tz(sql, params, tzname)
 70        fields = {
 71            "year": "%Y-01-01",
 72            "month": "%Y-%m-01",
 73        }
 74        if lookup_type in fields:
 75            format_str = fields[lookup_type]
 76            return f"CAST(DATE_FORMAT({sql}, %s) AS DATE)", (*params, format_str)
 77        elif lookup_type == "quarter":
 78            return (
 79                f"MAKEDATE(YEAR({sql}), 1) + "
 80                f"INTERVAL QUARTER({sql}) QUARTER - INTERVAL 1 QUARTER",
 81                (*params, *params),
 82            )
 83        elif lookup_type == "week":
 84            return f"DATE_SUB({sql}, INTERVAL WEEKDAY({sql}) DAY)", (*params, *params)
 85        else:
 86            return f"DATE({sql})", params
 87
 88    def _prepare_tzname_delta(self, tzname):
 89        tzname, sign, offset = split_tzname_delta(tzname)
 90        return f"{sign}{offset}" if offset else tzname
 91
 92    def _convert_sql_to_tz(self, sql, params, tzname):
 93        if tzname and self.connection.timezone_name != tzname:
 94            return f"CONVERT_TZ({sql}, %s, %s)", (
 95                *params,
 96                self.connection.timezone_name,
 97                self._prepare_tzname_delta(tzname),
 98            )
 99        return sql, params
100
101    def datetime_cast_date_sql(self, sql, params, tzname):
102        sql, params = self._convert_sql_to_tz(sql, params, tzname)
103        return f"DATE({sql})", params
104
105    def datetime_cast_time_sql(self, sql, params, tzname):
106        sql, params = self._convert_sql_to_tz(sql, params, tzname)
107        return f"TIME({sql})", params
108
109    def datetime_extract_sql(self, lookup_type, sql, params, tzname):
110        sql, params = self._convert_sql_to_tz(sql, params, tzname)
111        return self.date_extract_sql(lookup_type, sql, params)
112
113    def datetime_trunc_sql(self, lookup_type, sql, params, tzname):
114        sql, params = self._convert_sql_to_tz(sql, params, tzname)
115        fields = ["year", "month", "day", "hour", "minute", "second"]
116        format = ("%Y-", "%m", "-%d", " %H:", "%i", ":%s")
117        format_def = ("0000-", "01", "-01", " 00:", "00", ":00")
118        if lookup_type == "quarter":
119            return (
120                f"CAST(DATE_FORMAT(MAKEDATE(YEAR({sql}), 1) + "
121                f"INTERVAL QUARTER({sql}) QUARTER - "
122                f"INTERVAL 1 QUARTER, %s) AS DATETIME)"
123            ), (*params, *params, "%Y-%m-01 00:00:00")
124        if lookup_type == "week":
125            return (
126                f"CAST(DATE_FORMAT("
127                f"DATE_SUB({sql}, INTERVAL WEEKDAY({sql}) DAY), %s) AS DATETIME)"
128            ), (*params, *params, "%Y-%m-%d 00:00:00")
129        try:
130            i = fields.index(lookup_type) + 1
131        except ValueError:
132            pass
133        else:
134            format_str = "".join(format[:i] + format_def[i:])
135            return f"CAST(DATE_FORMAT({sql}, %s) AS DATETIME)", (*params, format_str)
136        return sql, params
137
138    def time_trunc_sql(self, lookup_type, sql, params, tzname=None):
139        sql, params = self._convert_sql_to_tz(sql, params, tzname)
140        fields = {
141            "hour": "%H:00:00",
142            "minute": "%H:%i:00",
143            "second": "%H:%i:%s",
144        }
145        if lookup_type in fields:
146            format_str = fields[lookup_type]
147            return f"CAST(DATE_FORMAT({sql}, %s) AS TIME)", (*params, format_str)
148        else:
149            return f"TIME({sql})", params
150
151    def fetch_returned_insert_rows(self, cursor):
152        """
153        Given a cursor object that has just performed an INSERT...RETURNING
154        statement into a table, return the tuple of returned data.
155        """
156        return cursor.fetchall()
157
158    def format_for_duration_arithmetic(self, sql):
159        return f"INTERVAL {sql} MICROSECOND"
160
161    def force_no_ordering(self):
162        """
163        "ORDER BY NULL" prevents MySQL from implicitly ordering by grouped
164        columns. If no ordering would otherwise be applied, we don't want any
165        implicit sorting going on.
166        """
167        return [(None, ("NULL", [], False))]
168
169    def adapt_decimalfield_value(self, value, max_digits=None, decimal_places=None):
170        return value
171
172    def last_executed_query(self, cursor, sql, params):
173        # With MySQLdb, cursor objects have an (undocumented) "_executed"
174        # attribute where the exact query sent to the database is saved.
175        # See MySQLdb/cursors.py in the source distribution.
176        # MySQLdb returns string, PyMySQL bytes.
177        return force_str(getattr(cursor, "_executed", None), errors="replace")
178
179    def no_limit_value(self):
180        # 2**64 - 1, as recommended by the MySQL documentation
181        return 18446744073709551615
182
183    def quote_name(self, name):
184        if name.startswith("`") and name.endswith("`"):
185            return name  # Quoting once is enough.
186        return f"`{name}`"
187
188    def return_insert_columns(self, fields):
189        # MySQL and MariaDB < 10.5.0 don't support an INSERT...RETURNING
190        # statement.
191        if not fields:
192            return "", ()
193        columns = [
194            f"{self.quote_name(field.model._meta.db_table)}.{self.quote_name(field.column)}"
195            for field in fields
196        ]
197        return "RETURNING {}".format(", ".join(columns)), ()
198
199    def validate_autopk_value(self, value):
200        # Zero in AUTO_INCREMENT field does not work without the
201        # NO_AUTO_VALUE_ON_ZERO SQL mode.
202        if value == 0 and not self.connection.features.allows_auto_pk_0:
203            raise ValueError(
204                "The database backend does not accept 0 as a value for AutoField."
205            )
206        return value
207
208    def adapt_datetimefield_value(self, value):
209        if value is None:
210            return None
211
212        # Expression values are adapted by the database.
213        if hasattr(value, "resolve_expression"):
214            return value
215
216        # MySQL doesn't support tz-aware datetimes
217        if timezone.is_aware(value):
218            value = timezone.make_naive(value, self.connection.timezone)
219        return str(value)
220
221    def adapt_timefield_value(self, value):
222        if value is None:
223            return None
224
225        # Expression values are adapted by the database.
226        if hasattr(value, "resolve_expression"):
227            return value
228
229        # MySQL doesn't support tz-aware times
230        if timezone.is_aware(value):
231            raise ValueError("MySQL backend does not support timezone-aware times.")
232
233        return value.isoformat(timespec="microseconds")
234
235    def max_name_length(self):
236        return 64
237
238    def pk_default_value(self):
239        return "NULL"
240
241    def bulk_insert_sql(self, fields, placeholder_rows):
242        placeholder_rows_sql = (", ".join(row) for row in placeholder_rows)
243        values_sql = ", ".join(f"({sql})" for sql in placeholder_rows_sql)
244        return "VALUES " + values_sql
245
246    def combine_expression(self, connector, sub_expressions):
247        if connector == "^":
248            return "POW({})".format(",".join(sub_expressions))
249        # Convert the result to a signed integer since MySQL's binary operators
250        # return an unsigned integer.
251        elif connector in ("&", "|", "<<", "#"):
252            connector = "^" if connector == "#" else connector
253            return f"CONVERT({connector.join(sub_expressions)}, SIGNED)"
254        elif connector == ">>":
255            lhs, rhs = sub_expressions
256            return f"FLOOR({lhs} / POW(2, {rhs}))"
257        return super().combine_expression(connector, sub_expressions)
258
259    def get_db_converters(self, expression):
260        converters = super().get_db_converters(expression)
261        internal_type = expression.output_field.get_internal_type()
262        if internal_type == "BooleanField":
263            converters.append(self.convert_booleanfield_value)
264        elif internal_type == "DateTimeField":
265            converters.append(self.convert_datetimefield_value)
266        elif internal_type == "UUIDField":
267            converters.append(self.convert_uuidfield_value)
268        return converters
269
270    def convert_booleanfield_value(self, value, expression, connection):
271        if value in (0, 1):
272            value = bool(value)
273        return value
274
275    def convert_datetimefield_value(self, value, expression, connection):
276        if value is not None:
277            value = timezone.make_aware(value, self.connection.timezone)
278        return value
279
280    def convert_uuidfield_value(self, value, expression, connection):
281        if value is not None:
282            value = uuid.UUID(value)
283        return value
284
285    def binary_placeholder_sql(self, value):
286        return (
287            "_binary %s" if value is not None and not hasattr(value, "as_sql") else "%s"
288        )
289
290    def subtract_temporals(self, internal_type, lhs, rhs):
291        lhs_sql, lhs_params = lhs
292        rhs_sql, rhs_params = rhs
293        if internal_type == "TimeField":
294            if self.connection.mysql_is_mariadb:
295                # MariaDB includes the microsecond component in TIME_TO_SEC as
296                # a decimal. MySQL returns an integer without microseconds.
297                return (
298                    f"CAST((TIME_TO_SEC({lhs_sql}) - TIME_TO_SEC({rhs_sql})) "
299                    "* 1000000 AS SIGNED)"
300                ), (
301                    *lhs_params,
302                    *rhs_params,
303                )
304            return (
305                f"((TIME_TO_SEC({lhs_sql}) * 1000000 + MICROSECOND({lhs_sql})) -"
306                f" (TIME_TO_SEC({rhs_sql}) * 1000000 + MICROSECOND({rhs_sql})))"
307            ), tuple(lhs_params) * 2 + tuple(rhs_params) * 2
308        params = (*rhs_params, *lhs_params)
309        return f"TIMESTAMPDIFF(MICROSECOND, {rhs_sql}, {lhs_sql})", params
310
311    def explain_query_prefix(self, format=None, **options):
312        # Alias MySQL's TRADITIONAL to TEXT for consistency with other backends.
313        if format and format.upper() == "TEXT":
314            format = "TRADITIONAL"
315        elif (
316            not format and "TREE" in self.connection.features.supported_explain_formats
317        ):
318            # Use TREE by default (if supported) as it's more informative.
319            format = "TREE"
320        analyze = options.pop("analyze", False)
321        prefix = super().explain_query_prefix(format, **options)
322        if analyze and self.connection.features.supports_explain_analyze:
323            # MariaDB uses ANALYZE instead of EXPLAIN ANALYZE.
324            prefix = (
325                "ANALYZE" if self.connection.mysql_is_mariadb else prefix + " ANALYZE"
326            )
327        if format and not (analyze and not self.connection.mysql_is_mariadb):
328            # Only MariaDB supports the analyze option with formats.
329            prefix += f" FORMAT={format}"
330        return prefix
331
332    def regex_lookup(self, lookup_type):
333        # REGEXP_LIKE doesn't exist in MariaDB.
334        if self.connection.mysql_is_mariadb:
335            if lookup_type == "regex":
336                return "%s REGEXP BINARY %s"
337            return "%s REGEXP %s"
338
339        match_option = "c" if lookup_type == "regex" else "i"
340        return f"REGEXP_LIKE(%s, %s, '{match_option}')"
341
342    def insert_statement(self, on_conflict=None):
343        if on_conflict == OnConflict.IGNORE:
344            return "INSERT IGNORE INTO"
345        return super().insert_statement(on_conflict=on_conflict)
346
347    def lookup_cast(self, lookup_type, internal_type=None):
348        lookup = "%s"
349        if internal_type == "JSONField":
350            if self.connection.mysql_is_mariadb or lookup_type in (
351                "iexact",
352                "contains",
353                "icontains",
354                "startswith",
355                "istartswith",
356                "endswith",
357                "iendswith",
358                "regex",
359                "iregex",
360            ):
361                lookup = "JSON_UNQUOTE(%s)"
362        return lookup
363
364    def conditional_expression_supported_in_where_clause(self, expression):
365        # MySQL ignores indexes with boolean fields unless they're compared
366        # directly to a boolean value.
367        if isinstance(expression, Exists | Lookup):
368            return True
369        if isinstance(expression, ExpressionWrapper) and expression.conditional:
370            return self.conditional_expression_supported_in_where_clause(
371                expression.expression
372            )
373        if getattr(expression, "conditional", False):
374            return False
375        return super().conditional_expression_supported_in_where_clause(expression)
376
377    def on_conflict_suffix_sql(self, fields, on_conflict, update_fields, unique_fields):
378        if on_conflict == OnConflict.UPDATE:
379            conflict_suffix_sql = "ON DUPLICATE KEY UPDATE %(fields)s"
380            # The use of VALUES() is deprecated in MySQL 8.0.20+. Instead, use
381            # aliases for the new row and its columns available in MySQL
382            # 8.0.19+.
383            if not self.connection.mysql_is_mariadb:
384                if self.connection.mysql_version >= (8, 0, 19):
385                    conflict_suffix_sql = f"AS new {conflict_suffix_sql}"
386                    field_sql = "%(field)s = new.%(field)s"
387                else:
388                    field_sql = "%(field)s = VALUES(%(field)s)"
389            # Use VALUE() on MariaDB.
390            else:
391                field_sql = "%(field)s = VALUE(%(field)s)"
392
393            fields = ", ".join(
394                [
395                    field_sql % {"field": field}
396                    for field in map(self.quote_name, update_fields)
397                ]
398            )
399            return conflict_suffix_sql % {"fields": fields}
400        return super().on_conflict_suffix_sql(
401            fields,
402            on_conflict,
403            update_fields,
404            unique_fields,
405        )