Plain is headed towards 1.0! Subscribe for development updates →

  1import uuid
  2
  3from plain.models.backends.base.operations import BaseDatabaseOperations
  4from plain.models.backends.utils import split_tzname_delta
  5from plain.models.constants import OnConflict
  6from plain.models.expressions import Exists, ExpressionWrapper
  7from plain.models.lookups import Lookup
  8from plain.utils import timezone
  9from plain.utils.encoding import force_str
 10from plain.utils.regex_helper import _lazy_re_compile
 11
 12
 13class DatabaseOperations(BaseDatabaseOperations):
 14    compiler_module = "plain.models.backends.mysql.compiler"
 15
 16    # MySQL stores positive fields as UNSIGNED ints.
 17    integer_field_ranges = {
 18        **BaseDatabaseOperations.integer_field_ranges,
 19        "PositiveSmallIntegerField": (0, 65535),
 20        "PositiveIntegerField": (0, 4294967295),
 21        "PositiveBigIntegerField": (0, 18446744073709551615),
 22    }
 23    cast_data_types = {
 24        "PrimaryKeyField": "signed integer",
 25        "CharField": "char(%(max_length)s)",
 26        "DecimalField": "decimal(%(max_digits)s, %(decimal_places)s)",
 27        "TextField": "char",
 28        "IntegerField": "signed integer",
 29        "BigIntegerField": "signed integer",
 30        "SmallIntegerField": "signed integer",
 31        "PositiveBigIntegerField": "unsigned integer",
 32        "PositiveIntegerField": "unsigned integer",
 33        "PositiveSmallIntegerField": "unsigned integer",
 34        "DurationField": "signed integer",
 35    }
 36    cast_char_field_without_max_length = "char"
 37    explain_prefix = "EXPLAIN"
 38
 39    # EXTRACT format cannot be passed in parameters.
 40    _extract_format_re = _lazy_re_compile(r"[A-Z_]+")
 41
 42    def date_extract_sql(self, lookup_type, sql, params):
 43        # https://dev.mysql.com/doc/mysql/en/date-and-time-functions.html
 44        if lookup_type == "week_day":
 45            # DAYOFWEEK() returns an integer, 1-7, Sunday=1.
 46            return f"DAYOFWEEK({sql})", params
 47        elif lookup_type == "iso_week_day":
 48            # WEEKDAY() returns an integer, 0-6, Monday=0.
 49            return f"WEEKDAY({sql}) + 1", params
 50        elif lookup_type == "week":
 51            # Override the value of default_week_format for consistency with
 52            # other database backends.
 53            # Mode 3: Monday, 1-53, with 4 or more days this year.
 54            return f"WEEK({sql}, 3)", params
 55        elif lookup_type == "iso_year":
 56            # Get the year part from the YEARWEEK function, which returns a
 57            # number as year * 100 + week.
 58            return f"TRUNCATE(YEARWEEK({sql}, 3), -2) / 100", params
 59        else:
 60            # EXTRACT returns 1-53 based on ISO-8601 for the week number.
 61            lookup_type = lookup_type.upper()
 62            if not self._extract_format_re.fullmatch(lookup_type):
 63                raise ValueError(f"Invalid loookup type: {lookup_type!r}")
 64            return f"EXTRACT({lookup_type} FROM {sql})", params
 65
 66    def date_trunc_sql(self, lookup_type, sql, params, tzname=None):
 67        sql, params = self._convert_sql_to_tz(sql, params, tzname)
 68        fields = {
 69            "year": "%Y-01-01",
 70            "month": "%Y-%m-01",
 71        }
 72        if lookup_type in fields:
 73            format_str = fields[lookup_type]
 74            return f"CAST(DATE_FORMAT({sql}, %s) AS DATE)", (*params, format_str)
 75        elif lookup_type == "quarter":
 76            return (
 77                f"MAKEDATE(YEAR({sql}), 1) + "
 78                f"INTERVAL QUARTER({sql}) QUARTER - INTERVAL 1 QUARTER",
 79                (*params, *params),
 80            )
 81        elif lookup_type == "week":
 82            return f"DATE_SUB({sql}, INTERVAL WEEKDAY({sql}) DAY)", (*params, *params)
 83        else:
 84            return f"DATE({sql})", params
 85
 86    def _prepare_tzname_delta(self, tzname):
 87        tzname, sign, offset = split_tzname_delta(tzname)
 88        return f"{sign}{offset}" if offset else tzname
 89
 90    def _convert_sql_to_tz(self, sql, params, tzname):
 91        if tzname and self.connection.timezone_name != tzname:
 92            return f"CONVERT_TZ({sql}, %s, %s)", (
 93                *params,
 94                self.connection.timezone_name,
 95                self._prepare_tzname_delta(tzname),
 96            )
 97        return sql, params
 98
 99    def datetime_cast_date_sql(self, sql, params, tzname):
100        sql, params = self._convert_sql_to_tz(sql, params, tzname)
101        return f"DATE({sql})", params
102
103    def datetime_cast_time_sql(self, sql, params, tzname):
104        sql, params = self._convert_sql_to_tz(sql, params, tzname)
105        return f"TIME({sql})", params
106
107    def datetime_extract_sql(self, lookup_type, sql, params, tzname):
108        sql, params = self._convert_sql_to_tz(sql, params, tzname)
109        return self.date_extract_sql(lookup_type, sql, params)
110
111    def datetime_trunc_sql(self, lookup_type, sql, params, tzname):
112        sql, params = self._convert_sql_to_tz(sql, params, tzname)
113        fields = ["year", "month", "day", "hour", "minute", "second"]
114        format = ("%Y-", "%m", "-%d", " %H:", "%i", ":%s")
115        format_def = ("0000-", "01", "-01", " 00:", "00", ":00")
116        if lookup_type == "quarter":
117            return (
118                f"CAST(DATE_FORMAT(MAKEDATE(YEAR({sql}), 1) + "
119                f"INTERVAL QUARTER({sql}) QUARTER - "
120                f"INTERVAL 1 QUARTER, %s) AS DATETIME)"
121            ), (*params, *params, "%Y-%m-01 00:00:00")
122        if lookup_type == "week":
123            return (
124                f"CAST(DATE_FORMAT("
125                f"DATE_SUB({sql}, INTERVAL WEEKDAY({sql}) DAY), %s) AS DATETIME)"
126            ), (*params, *params, "%Y-%m-%d 00:00:00")
127        try:
128            i = fields.index(lookup_type) + 1
129        except ValueError:
130            pass
131        else:
132            format_str = "".join(format[:i] + format_def[i:])
133            return f"CAST(DATE_FORMAT({sql}, %s) AS DATETIME)", (*params, format_str)
134        return sql, params
135
136    def time_trunc_sql(self, lookup_type, sql, params, tzname=None):
137        sql, params = self._convert_sql_to_tz(sql, params, tzname)
138        fields = {
139            "hour": "%H:00:00",
140            "minute": "%H:%i:00",
141            "second": "%H:%i:%s",
142        }
143        if lookup_type in fields:
144            format_str = fields[lookup_type]
145            return f"CAST(DATE_FORMAT({sql}, %s) AS TIME)", (*params, format_str)
146        else:
147            return f"TIME({sql})", params
148
149    def fetch_returned_insert_rows(self, cursor):
150        """
151        Given a cursor object that has just performed an INSERT...RETURNING
152        statement into a table, return the tuple of returned data.
153        """
154        return cursor.fetchall()
155
156    def format_for_duration_arithmetic(self, sql):
157        return f"INTERVAL {sql} MICROSECOND"
158
159    def force_no_ordering(self):
160        """
161        "ORDER BY NULL" prevents MySQL from implicitly ordering by grouped
162        columns. If no ordering would otherwise be applied, we don't want any
163        implicit sorting going on.
164        """
165        return [(None, ("NULL", [], False))]
166
167    def adapt_decimalfield_value(self, value, max_digits=None, decimal_places=None):
168        return value
169
170    def last_executed_query(self, cursor, sql, params):
171        # With MySQLdb, cursor objects have an (undocumented) "_executed"
172        # attribute where the exact query sent to the database is saved.
173        # See MySQLdb/cursors.py in the source distribution.
174        # MySQLdb returns string, PyMySQL bytes.
175        return force_str(getattr(cursor, "_executed", None), errors="replace")
176
177    def no_limit_value(self):
178        # 2**64 - 1, as recommended by the MySQL documentation
179        return 18446744073709551615
180
181    def quote_name(self, name):
182        if name.startswith("`") and name.endswith("`"):
183            return name  # Quoting once is enough.
184        return f"`{name}`"
185
186    def return_insert_columns(self, fields):
187        # MySQL and MariaDB < 10.5.0 don't support an INSERT...RETURNING
188        # statement.
189        if not fields:
190            return "", ()
191        columns = [
192            f"{self.quote_name(field.model._meta.db_table)}.{self.quote_name(field.column)}"
193            for field in fields
194        ]
195        return "RETURNING {}".format(", ".join(columns)), ()
196
197    def validate_autopk_value(self, value):
198        # Zero in AUTO_INCREMENT field does not work without the
199        # NO_AUTO_VALUE_ON_ZERO SQL mode.
200        if value == 0 and not self.connection.features.allows_auto_pk_0:
201            raise ValueError(
202                "The database backend does not accept 0 as a value for PrimaryKeyField."
203            )
204        return value
205
206    def adapt_datetimefield_value(self, value):
207        if value is None:
208            return None
209
210        # Expression values are adapted by the database.
211        if hasattr(value, "resolve_expression"):
212            return value
213
214        # MySQL doesn't support tz-aware datetimes
215        if timezone.is_aware(value):
216            value = timezone.make_naive(value, self.connection.timezone)
217        return str(value)
218
219    def adapt_timefield_value(self, value):
220        if value is None:
221            return None
222
223        # Expression values are adapted by the database.
224        if hasattr(value, "resolve_expression"):
225            return value
226
227        # MySQL doesn't support tz-aware times
228        if timezone.is_aware(value):
229            raise ValueError("MySQL backend does not support timezone-aware times.")
230
231        return value.isoformat(timespec="microseconds")
232
233    def max_name_length(self):
234        return 64
235
236    def pk_default_value(self):
237        return "NULL"
238
239    def bulk_insert_sql(self, fields, placeholder_rows):
240        placeholder_rows_sql = (", ".join(row) for row in placeholder_rows)
241        values_sql = ", ".join(f"({sql})" for sql in placeholder_rows_sql)
242        return "VALUES " + values_sql
243
244    def combine_expression(self, connector, sub_expressions):
245        if connector == "^":
246            return "POW({})".format(",".join(sub_expressions))
247        # Convert the result to a signed integer since MySQL's binary operators
248        # return an unsigned integer.
249        elif connector in ("&", "|", "<<", "#"):
250            connector = "^" if connector == "#" else connector
251            return f"CONVERT({connector.join(sub_expressions)}, SIGNED)"
252        elif connector == ">>":
253            lhs, rhs = sub_expressions
254            return f"FLOOR({lhs} / POW(2, {rhs}))"
255        return super().combine_expression(connector, sub_expressions)
256
257    def get_db_converters(self, expression):
258        converters = super().get_db_converters(expression)
259        internal_type = expression.output_field.get_internal_type()
260        if internal_type == "BooleanField":
261            converters.append(self.convert_booleanfield_value)
262        elif internal_type == "DateTimeField":
263            converters.append(self.convert_datetimefield_value)
264        elif internal_type == "UUIDField":
265            converters.append(self.convert_uuidfield_value)
266        return converters
267
268    def convert_booleanfield_value(self, value, expression, connection):
269        if value in (0, 1):
270            value = bool(value)
271        return value
272
273    def convert_datetimefield_value(self, value, expression, connection):
274        if value is not None:
275            value = timezone.make_aware(value, self.connection.timezone)
276        return value
277
278    def convert_uuidfield_value(self, value, expression, connection):
279        if value is not None:
280            value = uuid.UUID(value)
281        return value
282
283    def binary_placeholder_sql(self, value):
284        return (
285            "_binary %s" if value is not None and not hasattr(value, "as_sql") else "%s"
286        )
287
288    def subtract_temporals(self, internal_type, lhs, rhs):
289        lhs_sql, lhs_params = lhs
290        rhs_sql, rhs_params = rhs
291        if internal_type == "TimeField":
292            if self.connection.mysql_is_mariadb:
293                # MariaDB includes the microsecond component in TIME_TO_SEC as
294                # a decimal. MySQL returns an integer without microseconds.
295                return (
296                    f"CAST((TIME_TO_SEC({lhs_sql}) - TIME_TO_SEC({rhs_sql})) "
297                    "* 1000000 AS SIGNED)"
298                ), (
299                    *lhs_params,
300                    *rhs_params,
301                )
302            return (
303                f"((TIME_TO_SEC({lhs_sql}) * 1000000 + MICROSECOND({lhs_sql})) -"
304                f" (TIME_TO_SEC({rhs_sql}) * 1000000 + MICROSECOND({rhs_sql})))"
305            ), tuple(lhs_params) * 2 + tuple(rhs_params) * 2
306        params = (*rhs_params, *lhs_params)
307        return f"TIMESTAMPDIFF(MICROSECOND, {rhs_sql}, {lhs_sql})", params
308
309    def explain_query_prefix(self, format=None, **options):
310        # Alias MySQL's TRADITIONAL to TEXT for consistency with other backends.
311        if format and format.upper() == "TEXT":
312            format = "TRADITIONAL"
313        elif (
314            not format and "TREE" in self.connection.features.supported_explain_formats
315        ):
316            # Use TREE by default (if supported) as it's more informative.
317            format = "TREE"
318        analyze = options.pop("analyze", False)
319        prefix = super().explain_query_prefix(format, **options)
320        if analyze and self.connection.features.supports_explain_analyze:
321            # MariaDB uses ANALYZE instead of EXPLAIN ANALYZE.
322            prefix = (
323                "ANALYZE" if self.connection.mysql_is_mariadb else prefix + " ANALYZE"
324            )
325        if format and not (analyze and not self.connection.mysql_is_mariadb):
326            # Only MariaDB supports the analyze option with formats.
327            prefix += f" FORMAT={format}"
328        return prefix
329
330    def regex_lookup(self, lookup_type):
331        # REGEXP_LIKE doesn't exist in MariaDB.
332        if self.connection.mysql_is_mariadb:
333            if lookup_type == "regex":
334                return "%s REGEXP BINARY %s"
335            return "%s REGEXP %s"
336
337        match_option = "c" if lookup_type == "regex" else "i"
338        return f"REGEXP_LIKE(%s, %s, '{match_option}')"
339
340    def insert_statement(self, on_conflict=None):
341        if on_conflict == OnConflict.IGNORE:
342            return "INSERT IGNORE INTO"
343        return super().insert_statement(on_conflict=on_conflict)
344
345    def lookup_cast(self, lookup_type, internal_type=None):
346        lookup = "%s"
347        if internal_type == "JSONField":
348            if self.connection.mysql_is_mariadb or lookup_type in (
349                "iexact",
350                "contains",
351                "icontains",
352                "startswith",
353                "istartswith",
354                "endswith",
355                "iendswith",
356                "regex",
357                "iregex",
358            ):
359                lookup = "JSON_UNQUOTE(%s)"
360        return lookup
361
362    def conditional_expression_supported_in_where_clause(self, expression):
363        # MySQL ignores indexes with boolean fields unless they're compared
364        # directly to a boolean value.
365        if isinstance(expression, Exists | Lookup):
366            return True
367        if isinstance(expression, ExpressionWrapper) and expression.conditional:
368            return self.conditional_expression_supported_in_where_clause(
369                expression.expression
370            )
371        if getattr(expression, "conditional", False):
372            return False
373        return super().conditional_expression_supported_in_where_clause(expression)
374
375    def on_conflict_suffix_sql(self, fields, on_conflict, update_fields, unique_fields):
376        if on_conflict == OnConflict.UPDATE:
377            conflict_suffix_sql = "ON DUPLICATE KEY UPDATE %(fields)s"
378            # The use of VALUES() is deprecated in MySQL 8.0.20+. Instead, use
379            # aliases for the new row and its columns available in MySQL
380            # 8.0.19+.
381            if not self.connection.mysql_is_mariadb:
382                if self.connection.mysql_version >= (8, 0, 19):
383                    conflict_suffix_sql = f"AS new {conflict_suffix_sql}"
384                    field_sql = "%(field)s = new.%(field)s"
385                else:
386                    field_sql = "%(field)s = VALUES(%(field)s)"
387            # Use VALUE() on MariaDB.
388            else:
389                field_sql = "%(field)s = VALUE(%(field)s)"
390
391            fields = ", ".join(
392                [
393                    field_sql % {"field": field}
394                    for field in map(self.quote_name, update_fields)
395                ]
396            )
397            return conflict_suffix_sql % {"fields": fields}
398        return super().on_conflict_suffix_sql(
399            fields,
400            on_conflict,
401            update_fields,
402            unique_fields,
403        )