1import uuid
2
3from plain.models.backends.base.operations import BaseDatabaseOperations
4from plain.models.backends.utils import split_tzname_delta
5from plain.models.constants import OnConflict
6from plain.models.expressions import Exists, ExpressionWrapper
7from plain.models.lookups import Lookup
8from plain.runtime import settings
9from plain.utils import timezone
10from plain.utils.encoding import force_str
11from plain.utils.regex_helper import _lazy_re_compile
12
13
class DatabaseOperations(BaseDatabaseOperations):
    """
    MySQL/MariaDB-specific database operations.

    Most hooks here return an ``(sql, params)`` pair. Where MySQL and MariaDB
    diverge, ``self.connection.mysql_is_mariadb`` (and, for upserts,
    ``self.connection.mysql_version``) selects the dialect.
    """

    compiler_module = "plain.models.backends.mysql.compiler"

    # MySQL stores positive fields as UNSIGNED ints.
    integer_field_ranges = {
        **BaseDatabaseOperations.integer_field_ranges,
        "PositiveSmallIntegerField": (0, 65535),
        "PositiveIntegerField": (0, 4294967295),
        "PositiveBigIntegerField": (0, 18446744073709551615),
    }
    cast_data_types = {
        "AutoField": "signed integer",
        "BigAutoField": "signed integer",
        "SmallAutoField": "signed integer",
        "CharField": "char(%(max_length)s)",
        "DecimalField": "decimal(%(max_digits)s, %(decimal_places)s)",
        "TextField": "char",
        "IntegerField": "signed integer",
        "BigIntegerField": "signed integer",
        "SmallIntegerField": "signed integer",
        "PositiveBigIntegerField": "unsigned integer",
        "PositiveIntegerField": "unsigned integer",
        "PositiveSmallIntegerField": "unsigned integer",
        "DurationField": "signed integer",
    }
    cast_char_field_without_max_length = "char"
    explain_prefix = "EXPLAIN"

    # The EXTRACT unit cannot be passed as a query parameter, so it is
    # interpolated into the SQL; only upper-case letters and underscores are
    # accepted to keep arbitrary text out of the statement.
    _extract_format_re = _lazy_re_compile(r"[A-Z_]+")

    def date_extract_sql(self, lookup_type, sql, params):
        """
        Return ``(sql, params)`` extracting ``lookup_type`` from ``sql``.

        ``week_day``, ``iso_week_day``, ``week`` and ``iso_year`` map to
        dedicated MySQL functions; any other lookup type is upper-cased and
        used as the EXTRACT unit.

        Raises ValueError if that EXTRACT unit contains anything other than
        upper-case letters and underscores.
        """
        # https://dev.mysql.com/doc/mysql/en/date-and-time-functions.html
        if lookup_type == "week_day":
            # DAYOFWEEK() returns an integer, 1-7, Sunday=1.
            return f"DAYOFWEEK({sql})", params
        elif lookup_type == "iso_week_day":
            # WEEKDAY() returns an integer, 0-6, Monday=0.
            return f"WEEKDAY({sql}) + 1", params
        elif lookup_type == "week":
            # Override the value of default_week_format for consistency with
            # other database backends.
            # Mode 3: Monday, 1-53, with 4 or more days this year.
            return f"WEEK({sql}, 3)", params
        elif lookup_type == "iso_year":
            # Get the year part from the YEARWEEK function, which returns a
            # number as year * 100 + week.
            return f"TRUNCATE(YEARWEEK({sql}, 3), -2) / 100", params
        else:
            # The unit is interpolated rather than bound, so validate it
            # before it reaches the SQL string.
            lookup_type = lookup_type.upper()
            if not self._extract_format_re.fullmatch(lookup_type):
                raise ValueError(f"Invalid lookup type: {lookup_type!r}")
            return f"EXTRACT({lookup_type} FROM {sql})", params

    def date_trunc_sql(self, lookup_type, sql, params, tzname=None):
        """
        Return ``(sql, params)`` truncating ``sql`` to a DATE.

        ``year``/``month`` use DATE_FORMAT, ``quarter``/``week`` use date
        arithmetic, and any other lookup type falls back to ``DATE(sql)``.
        The expression is converted to ``tzname`` first when required.
        """
        sql, params = self._convert_sql_to_tz(sql, params, tzname)
        fields = {
            "year": "%Y-01-01",
            "month": "%Y-%m-01",
        }
        if lookup_type in fields:
            format_str = fields[lookup_type]
            return f"CAST(DATE_FORMAT({sql}, %s) AS DATE)", (*params, format_str)
        elif lookup_type == "quarter":
            # ``sql`` appears twice, so its params are repeated in order.
            return (
                f"MAKEDATE(YEAR({sql}), 1) + "
                f"INTERVAL QUARTER({sql}) QUARTER - INTERVAL 1 QUARTER",
                (*params, *params),
            )
        elif lookup_type == "week":
            # Step back WEEKDAY() days to reach the preceding Monday.
            return f"DATE_SUB({sql}, INTERVAL WEEKDAY({sql}) DAY)", (*params, *params)
        else:
            return f"DATE({sql})", params

    def _prepare_tzname_delta(self, tzname):
        """
        Return ``sign + offset`` when ``split_tzname_delta`` finds an offset
        in ``tzname``; otherwise return the zone name it produced.
        """
        tzname, sign, offset = split_tzname_delta(tzname)
        return f"{sign}{offset}" if offset else tzname

    def _convert_sql_to_tz(self, sql, params, tzname):
        """
        Wrap ``sql`` in CONVERT_TZ from the connection's time zone to
        ``tzname``. This happens only when ``tzname`` is given, USE_TZ is
        enabled and the two zones differ; otherwise ``(sql, params)`` is
        returned unchanged.
        """
        if tzname and settings.USE_TZ and self.connection.timezone_name != tzname:
            return f"CONVERT_TZ({sql}, %s, %s)", (
                *params,
                self.connection.timezone_name,
                self._prepare_tzname_delta(tzname),
            )
        return sql, params

    def datetime_cast_date_sql(self, sql, params, tzname):
        """Return ``(sql, params)`` for the DATE part of ``sql`` in ``tzname``."""
        sql, params = self._convert_sql_to_tz(sql, params, tzname)
        return f"DATE({sql})", params

    def datetime_cast_time_sql(self, sql, params, tzname):
        """Return ``(sql, params)`` for the TIME part of ``sql`` in ``tzname``."""
        sql, params = self._convert_sql_to_tz(sql, params, tzname)
        return f"TIME({sql})", params

    def datetime_extract_sql(self, lookup_type, sql, params, tzname):
        """Like ``date_extract_sql`` but converts ``sql`` to ``tzname`` first."""
        sql, params = self._convert_sql_to_tz(sql, params, tzname)
        return self.date_extract_sql(lookup_type, sql, params)

    def datetime_trunc_sql(self, lookup_type, sql, params, tzname):
        """
        Return ``(sql, params)`` truncating the datetime ``sql`` to
        ``lookup_type`` as a DATETIME, after conversion to ``tzname``.

        Unknown lookup types return the (converted) expression unchanged.
        """
        sql, params = self._convert_sql_to_tz(sql, params, tzname)
        fields = ["year", "month", "day", "hour", "minute", "second"]
        # DATE_FORMAT pieces kept for each unit up to the truncation point,
        # and the constant pieces that replace everything finer than it.
        field_formats = ("%Y-", "%m", "-%d", " %H:", "%i", ":%s")
        field_defaults = ("0000-", "01", "-01", " 00:", "00", ":00")
        if lookup_type == "quarter":
            return (
                f"CAST(DATE_FORMAT(MAKEDATE(YEAR({sql}), 1) + "
                f"INTERVAL QUARTER({sql}) QUARTER - "
                f"INTERVAL 1 QUARTER, %s) AS DATETIME)"
            ), (*params, *params, "%Y-%m-01 00:00:00")
        if lookup_type == "week":
            return (
                f"CAST(DATE_FORMAT("
                f"DATE_SUB({sql}, INTERVAL WEEKDAY({sql}) DAY), %s) AS DATETIME)"
            ), (*params, *params, "%Y-%m-%d 00:00:00")
        try:
            i = fields.index(lookup_type) + 1
        except ValueError:
            pass
        else:
            format_str = "".join(field_formats[:i] + field_defaults[i:])
            return f"CAST(DATE_FORMAT({sql}, %s) AS DATETIME)", (*params, format_str)
        return sql, params

    def time_trunc_sql(self, lookup_type, sql, params, tzname=None):
        """
        Return ``(sql, params)`` truncating ``sql`` to the hour, minute or
        second as a TIME; any other lookup type yields ``TIME(sql)``.
        """
        sql, params = self._convert_sql_to_tz(sql, params, tzname)
        fields = {
            "hour": "%H:00:00",
            "minute": "%H:%i:00",
            "second": "%H:%i:%s",
        }
        if lookup_type in fields:
            format_str = fields[lookup_type]
            return f"CAST(DATE_FORMAT({sql}, %s) AS TIME)", (*params, format_str)
        else:
            return f"TIME({sql})", params

    def fetch_returned_insert_rows(self, cursor):
        """
        Given a cursor object that has just performed an INSERT...RETURNING
        statement into a table, return the tuple of returned data.
        """
        return cursor.fetchall()

    def format_for_duration_arithmetic(self, sql):
        """Wrap a microsecond count ``sql`` in a MySQL INTERVAL expression."""
        return "INTERVAL %s MICROSECOND" % sql

    def force_no_ordering(self):
        """
        "ORDER BY NULL" prevents MySQL from implicitly ordering by grouped
        columns. If no ordering would otherwise be applied, we don't want any
        implicit sorting going on.
        """
        return [(None, ("NULL", [], False))]

    def adapt_decimalfield_value(self, value, max_digits=None, decimal_places=None):
        """Return ``value`` unchanged; ``max_digits``/``decimal_places`` are unused."""
        return value

    def last_executed_query(self, cursor, sql, params):
        """Return the query text the driver recorded on ``cursor``."""
        # With MySQLdb, cursor objects have an (undocumented) "_executed"
        # attribute where the exact query sent to the database is saved.
        # See MySQLdb/cursors.py in the source distribution.
        # MySQLdb returns string, PyMySQL bytes.
        return force_str(getattr(cursor, "_executed", None), errors="replace")

    def no_limit_value(self):
        """Return the value used as LIMIT when no limit is wanted."""
        # 2**64 - 1, as recommended by the MySQL documentation
        return 18446744073709551615

    def quote_name(self, name):
        """Wrap ``name`` in backticks unless it already starts and ends with one."""
        if name.startswith("`") and name.endswith("`"):
            return name  # Quoting once is enough.
        return "`%s`" % name

    def return_insert_columns(self, fields):
        """
        Return ``(sql, params)`` for a RETURNING clause listing the
        table-qualified columns of ``fields``, or ``("", ())`` when
        ``fields`` is empty.
        """
        # MySQL and MariaDB < 10.5.0 don't support an INSERT...RETURNING
        # statement.
        if not fields:
            return "", ()
        columns = [
            "{}.{}".format(
                self.quote_name(field.model._meta.db_table),
                self.quote_name(field.column),
            )
            for field in fields
        ]
        return "RETURNING %s" % ", ".join(columns), ()

    def sequence_reset_by_name_sql(self, style, sequences):
        """
        Return one ``ALTER TABLE <table> AUTO_INCREMENT = 1;`` statement per
        entry in ``sequences``, read from each entry's ``"table"`` key.
        """
        return [
            "{} {} {} {} = 1;".format(
                style.SQL_KEYWORD("ALTER"),
                style.SQL_KEYWORD("TABLE"),
                style.SQL_FIELD(self.quote_name(sequence_info["table"])),
                style.SQL_FIELD("AUTO_INCREMENT"),
            )
            for sequence_info in sequences
        ]

    def validate_autopk_value(self, value):
        """
        Return ``value``. Raises ValueError for 0 unless the connection's
        features report ``allows_auto_pk_0``.
        """
        # Zero in AUTO_INCREMENT field does not work without the
        # NO_AUTO_VALUE_ON_ZERO SQL mode.
        if value == 0 and not self.connection.features.allows_auto_pk_0:
            raise ValueError(
                "The database backend does not accept 0 as a value for AutoField."
            )
        return value

    def adapt_datetimefield_value(self, value):
        """
        Convert a datetime to the string sent to MySQL.

        ``None`` and expressions are returned as-is. Aware datetimes are made
        naive in the connection's time zone when USE_TZ is enabled.

        Raises ValueError for an aware datetime when USE_TZ is disabled.
        """
        if value is None:
            return None

        # Expression values are adapted by the database.
        if hasattr(value, "resolve_expression"):
            return value

        # MySQL doesn't support tz-aware datetimes
        if timezone.is_aware(value):
            if settings.USE_TZ:
                value = timezone.make_naive(value, self.connection.timezone)
            else:
                raise ValueError(
                    "MySQL backend does not support timezone-aware datetimes when "
                    "USE_TZ is False."
                )
        return str(value)

    def adapt_timefield_value(self, value):
        """
        Convert a time to an ISO string with microsecond precision.

        ``None`` and expressions are returned as-is.

        Raises ValueError for timezone-aware times.
        """
        if value is None:
            return None

        # Expression values are adapted by the database.
        if hasattr(value, "resolve_expression"):
            return value

        # MySQL doesn't support tz-aware times
        if timezone.is_aware(value):
            raise ValueError("MySQL backend does not support timezone-aware times.")

        return value.isoformat(timespec="microseconds")

    def max_name_length(self):
        """Return the maximum identifier length used by this backend (64)."""
        return 64

    def pk_default_value(self):
        """Return the SQL literal used as the default primary key value."""
        return "NULL"

    def bulk_insert_sql(self, fields, placeholder_rows):
        """
        Build a multi-row ``VALUES`` clause with one parenthesised group per
        row in ``placeholder_rows``. ``fields`` is not used here.
        """
        placeholder_rows_sql = (", ".join(row) for row in placeholder_rows)
        values_sql = ", ".join("(%s)" % sql for sql in placeholder_rows_sql)
        return "VALUES " + values_sql

    def combine_expression(self, connector, sub_expressions):
        """
        Combine ``sub_expressions`` with ``connector`` using MySQL syntax.

        ``^`` becomes POW(), ``#`` becomes MySQL's ``^`` (bitwise XOR), and
        ``>>`` is emulated with FLOOR/POW. Other connectors are delegated to
        the base implementation.
        """
        if connector == "^":
            return "POW(%s)" % ",".join(sub_expressions)
        # Convert the result to a signed integer since MySQL's binary operators
        # return an unsigned integer.
        elif connector in ("&", "|", "<<", "#"):
            connector = "^" if connector == "#" else connector
            return "CONVERT(%s, SIGNED)" % connector.join(sub_expressions)
        elif connector == ">>":
            lhs, rhs = sub_expressions
            return f"FLOOR({lhs} / POW(2, {rhs}))"
        return super().combine_expression(connector, sub_expressions)

    def get_db_converters(self, expression):
        """
        Extend the base converters with MySQL-specific ones.

        Adds converters for BooleanField, DateTimeField (only when USE_TZ is
        enabled) and UUIDField.
        """
        converters = super().get_db_converters(expression)
        internal_type = expression.output_field.get_internal_type()
        if internal_type == "BooleanField":
            converters.append(self.convert_booleanfield_value)
        elif internal_type == "DateTimeField":
            if settings.USE_TZ:
                converters.append(self.convert_datetimefield_value)
        elif internal_type == "UUIDField":
            converters.append(self.convert_uuidfield_value)
        return converters

    def convert_booleanfield_value(self, value, expression, connection):
        """Map 0/1 to ``False``/``True``; any other value passes through."""
        if value in (0, 1):
            value = bool(value)
        return value

    def convert_datetimefield_value(self, value, expression, connection):
        """Make non-null values aware in the connection's time zone."""
        if value is not None:
            value = timezone.make_aware(value, self.connection.timezone)
        return value

    def convert_uuidfield_value(self, value, expression, connection):
        """Parse non-null values into ``uuid.UUID``."""
        if value is not None:
            value = uuid.UUID(value)
        return value

    def binary_placeholder_sql(self, value):
        """
        Return the placeholder for a binary value: ``_binary %s`` for plain
        values, and a bare ``%s`` for ``None`` or objects with ``as_sql``.
        """
        return (
            "_binary %s" if value is not None and not hasattr(value, "as_sql") else "%s"
        )

    def subtract_temporals(self, internal_type, lhs, rhs):
        """
        Return ``(sql, params)`` computing ``lhs - rhs`` in microseconds.

        ``lhs`` and ``rhs`` are ``(sql, params)`` pairs. TimeField operands
        use TIME_TO_SEC, in a form that differs between MariaDB and MySQL.
        Other types use TIMESTAMPDIFF.
        """
        lhs_sql, lhs_params = lhs
        rhs_sql, rhs_params = rhs
        if internal_type == "TimeField":
            if self.connection.mysql_is_mariadb:
                # MariaDB includes the microsecond component in TIME_TO_SEC as
                # a decimal. MySQL returns an integer without microseconds.
                return (
                    f"CAST((TIME_TO_SEC({lhs_sql}) - TIME_TO_SEC({rhs_sql})) "
                    "* 1000000 AS SIGNED)"
                ), (
                    *lhs_params,
                    *rhs_params,
                )
            # Each operand is referenced twice, so its params are repeated.
            return (
                f"((TIME_TO_SEC({lhs_sql}) * 1000000 + MICROSECOND({lhs_sql})) -"
                f" (TIME_TO_SEC({rhs_sql}) * 1000000 + MICROSECOND({rhs_sql})))"
            ), tuple(lhs_params) * 2 + tuple(rhs_params) * 2
        # TIMESTAMPDIFF takes the start first, so rhs params come first.
        params = (*rhs_params, *lhs_params)
        return f"TIMESTAMPDIFF(MICROSECOND, {rhs_sql}, {lhs_sql})", params

    def explain_query_prefix(self, format=None, **options):
        """
        Build the EXPLAIN prefix.

        Handles the TEXT/TRADITIONAL alias, the TREE default when supported,
        and the differing ANALYZE syntax of MySQL and MariaDB.
        """
        # Alias MySQL's TRADITIONAL to TEXT for consistency with other backends.
        if format and format.upper() == "TEXT":
            format = "TRADITIONAL"
        elif (
            not format and "TREE" in self.connection.features.supported_explain_formats
        ):
            # Use TREE by default (if supported) as it's more informative.
            format = "TREE"
        analyze = options.pop("analyze", False)
        prefix = super().explain_query_prefix(format, **options)
        if analyze and self.connection.features.supports_explain_analyze:
            # MariaDB uses ANALYZE instead of EXPLAIN ANALYZE.
            prefix = (
                "ANALYZE" if self.connection.mysql_is_mariadb else prefix + " ANALYZE"
            )
        if format and not (analyze and not self.connection.mysql_is_mariadb):
            # Only MariaDB supports the analyze option with formats.
            prefix += " FORMAT=%s" % format
        return prefix

    def regex_lookup(self, lookup_type):
        """
        Return the SQL template for ``regex`` (case-sensitive) or ``iregex``
        (case-insensitive) lookups.
        """
        # REGEXP_LIKE doesn't exist in MariaDB.
        if self.connection.mysql_is_mariadb:
            if lookup_type == "regex":
                return "%s REGEXP BINARY %s"
            return "%s REGEXP %s"

        match_option = "c" if lookup_type == "regex" else "i"
        return "REGEXP_LIKE(%%s, %%s, '%s')" % match_option

    def insert_statement(self, on_conflict=None):
        """Use ``INSERT IGNORE`` for OnConflict.IGNORE; otherwise defer to the base."""
        if on_conflict == OnConflict.IGNORE:
            return "INSERT IGNORE INTO"
        return super().insert_statement(on_conflict=on_conflict)

    def lookup_cast(self, lookup_type, internal_type=None):
        """
        Return the lookup-side SQL template.

        JSONField values are wrapped in JSON_UNQUOTE for the text-style
        lookups listed below, and for every lookup on MariaDB.
        """
        lookup = "%s"
        if internal_type == "JSONField":
            if self.connection.mysql_is_mariadb or lookup_type in (
                "iexact",
                "contains",
                "icontains",
                "startswith",
                "istartswith",
                "endswith",
                "iendswith",
                "regex",
                "iregex",
            ):
                lookup = "JSON_UNQUOTE(%s)"
        return lookup

    def conditional_expression_supported_in_where_clause(self, expression):
        """
        Return whether ``expression`` may appear bare in a WHERE clause.

        Exists/Lookup are accepted. Conditional ExpressionWrappers are
        unwrapped and checked recursively. Other conditional expressions are
        rejected.
        """
        # MySQL ignores indexes with boolean fields unless they're compared
        # directly to a boolean value.
        if isinstance(expression, Exists | Lookup):
            return True
        if isinstance(expression, ExpressionWrapper) and expression.conditional:
            return self.conditional_expression_supported_in_where_clause(
                expression.expression
            )
        if getattr(expression, "conditional", False):
            return False
        return super().conditional_expression_supported_in_where_clause(expression)

    def on_conflict_suffix_sql(self, fields, on_conflict, update_fields, unique_fields):
        """
        Return the upsert suffix for OnConflict.UPDATE as an
        ``ON DUPLICATE KEY UPDATE`` clause over ``update_fields``.
        Other modes are delegated to the base implementation.
        """
        if on_conflict == OnConflict.UPDATE:
            conflict_suffix_sql = "ON DUPLICATE KEY UPDATE %(fields)s"
            # The use of VALUES() is deprecated in MySQL 8.0.20+. Instead, use
            # aliases for the new row and its columns available in MySQL
            # 8.0.19+.
            if not self.connection.mysql_is_mariadb:
                if self.connection.mysql_version >= (8, 0, 19):
                    conflict_suffix_sql = f"AS new {conflict_suffix_sql}"
                    field_sql = "%(field)s = new.%(field)s"
                else:
                    field_sql = "%(field)s = VALUES(%(field)s)"
            # Use VALUE() on MariaDB.
            else:
                field_sql = "%(field)s = VALUE(%(field)s)"

            fields = ", ".join(
                [
                    field_sql % {"field": field}
                    for field in map(self.quote_name, update_fields)
                ]
            )
            return conflict_suffix_sql % {"fields": fields}
        return super().on_conflict_suffix_sql(
            fields,
            on_conflict,
            update_fields,
            unique_fields,
        )