Plain is headed towards 1.0! Subscribe for development updates →

  1import uuid
  2
  3from plain.models.backends.base.operations import BaseDatabaseOperations
  4from plain.models.backends.utils import split_tzname_delta
  5from plain.models.constants import OnConflict
  6from plain.models.expressions import Exists, ExpressionWrapper
  7from plain.models.lookups import Lookup
  8from plain.runtime import settings
  9from plain.utils import timezone
 10from plain.utils.encoding import force_str
 11from plain.utils.regex_helper import _lazy_re_compile
 12
 13
 14class DatabaseOperations(BaseDatabaseOperations):
 15    compiler_module = "plain.models.backends.mysql.compiler"
 16
 17    # MySQL stores positive fields as UNSIGNED ints.
 18    integer_field_ranges = {
 19        **BaseDatabaseOperations.integer_field_ranges,
 20        "PositiveSmallIntegerField": (0, 65535),
 21        "PositiveIntegerField": (0, 4294967295),
 22        "PositiveBigIntegerField": (0, 18446744073709551615),
 23    }
 24    cast_data_types = {
 25        "AutoField": "signed integer",
 26        "BigAutoField": "signed integer",
 27        "SmallAutoField": "signed integer",
 28        "CharField": "char(%(max_length)s)",
 29        "DecimalField": "decimal(%(max_digits)s, %(decimal_places)s)",
 30        "TextField": "char",
 31        "IntegerField": "signed integer",
 32        "BigIntegerField": "signed integer",
 33        "SmallIntegerField": "signed integer",
 34        "PositiveBigIntegerField": "unsigned integer",
 35        "PositiveIntegerField": "unsigned integer",
 36        "PositiveSmallIntegerField": "unsigned integer",
 37        "DurationField": "signed integer",
 38    }
 39    cast_char_field_without_max_length = "char"
 40    explain_prefix = "EXPLAIN"
 41
 42    # EXTRACT format cannot be passed in parameters.
 43    _extract_format_re = _lazy_re_compile(r"[A-Z_]+")
 44
 45    def date_extract_sql(self, lookup_type, sql, params):
 46        # https://dev.mysql.com/doc/mysql/en/date-and-time-functions.html
 47        if lookup_type == "week_day":
 48            # DAYOFWEEK() returns an integer, 1-7, Sunday=1.
 49            return f"DAYOFWEEK({sql})", params
 50        elif lookup_type == "iso_week_day":
 51            # WEEKDAY() returns an integer, 0-6, Monday=0.
 52            return f"WEEKDAY({sql}) + 1", params
 53        elif lookup_type == "week":
 54            # Override the value of default_week_format for consistency with
 55            # other database backends.
 56            # Mode 3: Monday, 1-53, with 4 or more days this year.
 57            return f"WEEK({sql}, 3)", params
 58        elif lookup_type == "iso_year":
 59            # Get the year part from the YEARWEEK function, which returns a
 60            # number as year * 100 + week.
 61            return f"TRUNCATE(YEARWEEK({sql}, 3), -2) / 100", params
 62        else:
 63            # EXTRACT returns 1-53 based on ISO-8601 for the week number.
 64            lookup_type = lookup_type.upper()
 65            if not self._extract_format_re.fullmatch(lookup_type):
 66                raise ValueError(f"Invalid loookup type: {lookup_type!r}")
 67            return f"EXTRACT({lookup_type} FROM {sql})", params
 68
 69    def date_trunc_sql(self, lookup_type, sql, params, tzname=None):
 70        sql, params = self._convert_sql_to_tz(sql, params, tzname)
 71        fields = {
 72            "year": "%Y-01-01",
 73            "month": "%Y-%m-01",
 74        }
 75        if lookup_type in fields:
 76            format_str = fields[lookup_type]
 77            return f"CAST(DATE_FORMAT({sql}, %s) AS DATE)", (*params, format_str)
 78        elif lookup_type == "quarter":
 79            return (
 80                f"MAKEDATE(YEAR({sql}), 1) + "
 81                f"INTERVAL QUARTER({sql}) QUARTER - INTERVAL 1 QUARTER",
 82                (*params, *params),
 83            )
 84        elif lookup_type == "week":
 85            return f"DATE_SUB({sql}, INTERVAL WEEKDAY({sql}) DAY)", (*params, *params)
 86        else:
 87            return f"DATE({sql})", params
 88
 89    def _prepare_tzname_delta(self, tzname):
 90        tzname, sign, offset = split_tzname_delta(tzname)
 91        return f"{sign}{offset}" if offset else tzname
 92
 93    def _convert_sql_to_tz(self, sql, params, tzname):
 94        if tzname and settings.USE_TZ and self.connection.timezone_name != tzname:
 95            return f"CONVERT_TZ({sql}, %s, %s)", (
 96                *params,
 97                self.connection.timezone_name,
 98                self._prepare_tzname_delta(tzname),
 99            )
100        return sql, params
101
102    def datetime_cast_date_sql(self, sql, params, tzname):
103        sql, params = self._convert_sql_to_tz(sql, params, tzname)
104        return f"DATE({sql})", params
105
106    def datetime_cast_time_sql(self, sql, params, tzname):
107        sql, params = self._convert_sql_to_tz(sql, params, tzname)
108        return f"TIME({sql})", params
109
110    def datetime_extract_sql(self, lookup_type, sql, params, tzname):
111        sql, params = self._convert_sql_to_tz(sql, params, tzname)
112        return self.date_extract_sql(lookup_type, sql, params)
113
114    def datetime_trunc_sql(self, lookup_type, sql, params, tzname):
115        sql, params = self._convert_sql_to_tz(sql, params, tzname)
116        fields = ["year", "month", "day", "hour", "minute", "second"]
117        format = ("%Y-", "%m", "-%d", " %H:", "%i", ":%s")
118        format_def = ("0000-", "01", "-01", " 00:", "00", ":00")
119        if lookup_type == "quarter":
120            return (
121                f"CAST(DATE_FORMAT(MAKEDATE(YEAR({sql}), 1) + "
122                f"INTERVAL QUARTER({sql}) QUARTER - "
123                f"INTERVAL 1 QUARTER, %s) AS DATETIME)"
124            ), (*params, *params, "%Y-%m-01 00:00:00")
125        if lookup_type == "week":
126            return (
127                f"CAST(DATE_FORMAT("
128                f"DATE_SUB({sql}, INTERVAL WEEKDAY({sql}) DAY), %s) AS DATETIME)"
129            ), (*params, *params, "%Y-%m-%d 00:00:00")
130        try:
131            i = fields.index(lookup_type) + 1
132        except ValueError:
133            pass
134        else:
135            format_str = "".join(format[:i] + format_def[i:])
136            return f"CAST(DATE_FORMAT({sql}, %s) AS DATETIME)", (*params, format_str)
137        return sql, params
138
139    def time_trunc_sql(self, lookup_type, sql, params, tzname=None):
140        sql, params = self._convert_sql_to_tz(sql, params, tzname)
141        fields = {
142            "hour": "%H:00:00",
143            "minute": "%H:%i:00",
144            "second": "%H:%i:%s",
145        }
146        if lookup_type in fields:
147            format_str = fields[lookup_type]
148            return f"CAST(DATE_FORMAT({sql}, %s) AS TIME)", (*params, format_str)
149        else:
150            return f"TIME({sql})", params
151
152    def fetch_returned_insert_rows(self, cursor):
153        """
154        Given a cursor object that has just performed an INSERT...RETURNING
155        statement into a table, return the tuple of returned data.
156        """
157        return cursor.fetchall()
158
159    def format_for_duration_arithmetic(self, sql):
160        return "INTERVAL %s MICROSECOND" % sql
161
162    def force_no_ordering(self):
163        """
164        "ORDER BY NULL" prevents MySQL from implicitly ordering by grouped
165        columns. If no ordering would otherwise be applied, we don't want any
166        implicit sorting going on.
167        """
168        return [(None, ("NULL", [], False))]
169
170    def adapt_decimalfield_value(self, value, max_digits=None, decimal_places=None):
171        return value
172
173    def last_executed_query(self, cursor, sql, params):
174        # With MySQLdb, cursor objects have an (undocumented) "_executed"
175        # attribute where the exact query sent to the database is saved.
176        # See MySQLdb/cursors.py in the source distribution.
177        # MySQLdb returns string, PyMySQL bytes.
178        return force_str(getattr(cursor, "_executed", None), errors="replace")
179
180    def no_limit_value(self):
181        # 2**64 - 1, as recommended by the MySQL documentation
182        return 18446744073709551615
183
184    def quote_name(self, name):
185        if name.startswith("`") and name.endswith("`"):
186            return name  # Quoting once is enough.
187        return "`%s`" % name
188
189    def return_insert_columns(self, fields):
190        # MySQL and MariaDB < 10.5.0 don't support an INSERT...RETURNING
191        # statement.
192        if not fields:
193            return "", ()
194        columns = [
195            "{}.{}".format(
196                self.quote_name(field.model._meta.db_table),
197                self.quote_name(field.column),
198            )
199            for field in fields
200        ]
201        return "RETURNING %s" % ", ".join(columns), ()
202
203    def sequence_reset_by_name_sql(self, style, sequences):
204        return [
205            "{} {} {} {} = 1;".format(
206                style.SQL_KEYWORD("ALTER"),
207                style.SQL_KEYWORD("TABLE"),
208                style.SQL_FIELD(self.quote_name(sequence_info["table"])),
209                style.SQL_FIELD("AUTO_INCREMENT"),
210            )
211            for sequence_info in sequences
212        ]
213
214    def validate_autopk_value(self, value):
215        # Zero in AUTO_INCREMENT field does not work without the
216        # NO_AUTO_VALUE_ON_ZERO SQL mode.
217        if value == 0 and not self.connection.features.allows_auto_pk_0:
218            raise ValueError(
219                "The database backend does not accept 0 as a value for AutoField."
220            )
221        return value
222
223    def adapt_datetimefield_value(self, value):
224        if value is None:
225            return None
226
227        # Expression values are adapted by the database.
228        if hasattr(value, "resolve_expression"):
229            return value
230
231        # MySQL doesn't support tz-aware datetimes
232        if timezone.is_aware(value):
233            if settings.USE_TZ:
234                value = timezone.make_naive(value, self.connection.timezone)
235            else:
236                raise ValueError(
237                    "MySQL backend does not support timezone-aware datetimes when "
238                    "USE_TZ is False."
239                )
240        return str(value)
241
242    def adapt_timefield_value(self, value):
243        if value is None:
244            return None
245
246        # Expression values are adapted by the database.
247        if hasattr(value, "resolve_expression"):
248            return value
249
250        # MySQL doesn't support tz-aware times
251        if timezone.is_aware(value):
252            raise ValueError("MySQL backend does not support timezone-aware times.")
253
254        return value.isoformat(timespec="microseconds")
255
256    def max_name_length(self):
257        return 64
258
259    def pk_default_value(self):
260        return "NULL"
261
262    def bulk_insert_sql(self, fields, placeholder_rows):
263        placeholder_rows_sql = (", ".join(row) for row in placeholder_rows)
264        values_sql = ", ".join("(%s)" % sql for sql in placeholder_rows_sql)
265        return "VALUES " + values_sql
266
267    def combine_expression(self, connector, sub_expressions):
268        if connector == "^":
269            return "POW(%s)" % ",".join(sub_expressions)
270        # Convert the result to a signed integer since MySQL's binary operators
271        # return an unsigned integer.
272        elif connector in ("&", "|", "<<", "#"):
273            connector = "^" if connector == "#" else connector
274            return "CONVERT(%s, SIGNED)" % connector.join(sub_expressions)
275        elif connector == ">>":
276            lhs, rhs = sub_expressions
277            return f"FLOOR({lhs} / POW(2, {rhs}))"
278        return super().combine_expression(connector, sub_expressions)
279
280    def get_db_converters(self, expression):
281        converters = super().get_db_converters(expression)
282        internal_type = expression.output_field.get_internal_type()
283        if internal_type == "BooleanField":
284            converters.append(self.convert_booleanfield_value)
285        elif internal_type == "DateTimeField":
286            if settings.USE_TZ:
287                converters.append(self.convert_datetimefield_value)
288        elif internal_type == "UUIDField":
289            converters.append(self.convert_uuidfield_value)
290        return converters
291
292    def convert_booleanfield_value(self, value, expression, connection):
293        if value in (0, 1):
294            value = bool(value)
295        return value
296
297    def convert_datetimefield_value(self, value, expression, connection):
298        if value is not None:
299            value = timezone.make_aware(value, self.connection.timezone)
300        return value
301
302    def convert_uuidfield_value(self, value, expression, connection):
303        if value is not None:
304            value = uuid.UUID(value)
305        return value
306
307    def binary_placeholder_sql(self, value):
308        return (
309            "_binary %s" if value is not None and not hasattr(value, "as_sql") else "%s"
310        )
311
312    def subtract_temporals(self, internal_type, lhs, rhs):
313        lhs_sql, lhs_params = lhs
314        rhs_sql, rhs_params = rhs
315        if internal_type == "TimeField":
316            if self.connection.mysql_is_mariadb:
317                # MariaDB includes the microsecond component in TIME_TO_SEC as
318                # a decimal. MySQL returns an integer without microseconds.
319                return (
320                    f"CAST((TIME_TO_SEC({lhs_sql}) - TIME_TO_SEC({rhs_sql})) "
321                    "* 1000000 AS SIGNED)"
322                ), (
323                    *lhs_params,
324                    *rhs_params,
325                )
326            return (
327                f"((TIME_TO_SEC({lhs_sql}) * 1000000 + MICROSECOND({lhs_sql})) -"
328                f" (TIME_TO_SEC({rhs_sql}) * 1000000 + MICROSECOND({rhs_sql})))"
329            ), tuple(lhs_params) * 2 + tuple(rhs_params) * 2
330        params = (*rhs_params, *lhs_params)
331        return f"TIMESTAMPDIFF(MICROSECOND, {rhs_sql}, {lhs_sql})", params
332
333    def explain_query_prefix(self, format=None, **options):
334        # Alias MySQL's TRADITIONAL to TEXT for consistency with other backends.
335        if format and format.upper() == "TEXT":
336            format = "TRADITIONAL"
337        elif (
338            not format and "TREE" in self.connection.features.supported_explain_formats
339        ):
340            # Use TREE by default (if supported) as it's more informative.
341            format = "TREE"
342        analyze = options.pop("analyze", False)
343        prefix = super().explain_query_prefix(format, **options)
344        if analyze and self.connection.features.supports_explain_analyze:
345            # MariaDB uses ANALYZE instead of EXPLAIN ANALYZE.
346            prefix = (
347                "ANALYZE" if self.connection.mysql_is_mariadb else prefix + " ANALYZE"
348            )
349        if format and not (analyze and not self.connection.mysql_is_mariadb):
350            # Only MariaDB supports the analyze option with formats.
351            prefix += " FORMAT=%s" % format
352        return prefix
353
354    def regex_lookup(self, lookup_type):
355        # REGEXP_LIKE doesn't exist in MariaDB.
356        if self.connection.mysql_is_mariadb:
357            if lookup_type == "regex":
358                return "%s REGEXP BINARY %s"
359            return "%s REGEXP %s"
360
361        match_option = "c" if lookup_type == "regex" else "i"
362        return "REGEXP_LIKE(%%s, %%s, '%s')" % match_option
363
364    def insert_statement(self, on_conflict=None):
365        if on_conflict == OnConflict.IGNORE:
366            return "INSERT IGNORE INTO"
367        return super().insert_statement(on_conflict=on_conflict)
368
369    def lookup_cast(self, lookup_type, internal_type=None):
370        lookup = "%s"
371        if internal_type == "JSONField":
372            if self.connection.mysql_is_mariadb or lookup_type in (
373                "iexact",
374                "contains",
375                "icontains",
376                "startswith",
377                "istartswith",
378                "endswith",
379                "iendswith",
380                "regex",
381                "iregex",
382            ):
383                lookup = "JSON_UNQUOTE(%s)"
384        return lookup
385
386    def conditional_expression_supported_in_where_clause(self, expression):
387        # MySQL ignores indexes with boolean fields unless they're compared
388        # directly to a boolean value.
389        if isinstance(expression, Exists | Lookup):
390            return True
391        if isinstance(expression, ExpressionWrapper) and expression.conditional:
392            return self.conditional_expression_supported_in_where_clause(
393                expression.expression
394            )
395        if getattr(expression, "conditional", False):
396            return False
397        return super().conditional_expression_supported_in_where_clause(expression)
398
399    def on_conflict_suffix_sql(self, fields, on_conflict, update_fields, unique_fields):
400        if on_conflict == OnConflict.UPDATE:
401            conflict_suffix_sql = "ON DUPLICATE KEY UPDATE %(fields)s"
402            # The use of VALUES() is deprecated in MySQL 8.0.20+. Instead, use
403            # aliases for the new row and its columns available in MySQL
404            # 8.0.19+.
405            if not self.connection.mysql_is_mariadb:
406                if self.connection.mysql_version >= (8, 0, 19):
407                    conflict_suffix_sql = f"AS new {conflict_suffix_sql}"
408                    field_sql = "%(field)s = new.%(field)s"
409                else:
410                    field_sql = "%(field)s = VALUES(%(field)s)"
411            # Use VALUE() on MariaDB.
412            else:
413                field_sql = "%(field)s = VALUE(%(field)s)"
414
415            fields = ", ".join(
416                [
417                    field_sql % {"field": field}
418                    for field in map(self.quote_name, update_fields)
419                ]
420            )
421            return conflict_suffix_sql % {"fields": fields}
422        return super().on_conflict_suffix_sql(
423            fields,
424            on_conflict,
425            update_fields,
426            unique_fields,
427        )