1import uuid
2
3from plain.models.backends.base.operations import BaseDatabaseOperations
4from plain.models.backends.utils import split_tzname_delta
5from plain.models.constants import OnConflict
6from plain.models.expressions import Exists, ExpressionWrapper
7from plain.models.lookups import Lookup
8from plain.utils import timezone
9from plain.utils.encoding import force_str
10from plain.utils.regex_helper import _lazy_re_compile
11
12
class DatabaseOperations(BaseDatabaseOperations):
    """
    MySQL/MariaDB-specific SQL generation and value adaptation.

    Overrides the generic hooks of ``BaseDatabaseOperations`` where the MySQL
    dialect differs, including:

    - date/time extraction and truncation;
    - identifier quoting;
    - bulk INSERT and ``ON DUPLICATE KEY UPDATE`` upserts;
    - EXPLAIN prefixes;
    - conversion of values read back from the database.
    """

    compiler_module = "plain.models.backends.mysql.compiler"

    # MySQL stores positive fields as UNSIGNED ints.
    integer_field_ranges = {
        **BaseDatabaseOperations.integer_field_ranges,
        "PositiveSmallIntegerField": (0, 65535),
        "PositiveIntegerField": (0, 4294967295),
        "PositiveBigIntegerField": (0, 18446744073709551615),
    }
    # Target types for CAST(... AS <type>), keyed by field internal type.
    cast_data_types = {
        "AutoField": "signed integer",
        "BigAutoField": "signed integer",
        "SmallAutoField": "signed integer",
        "CharField": "char(%(max_length)s)",
        "DecimalField": "decimal(%(max_digits)s, %(decimal_places)s)",
        "TextField": "char",
        "IntegerField": "signed integer",
        "BigIntegerField": "signed integer",
        "SmallIntegerField": "signed integer",
        "PositiveBigIntegerField": "unsigned integer",
        "PositiveIntegerField": "unsigned integer",
        "PositiveSmallIntegerField": "unsigned integer",
        "DurationField": "signed integer",
    }
    cast_char_field_without_max_length = "char"
    explain_prefix = "EXPLAIN"

    # EXTRACT format cannot be passed in parameters, so the unit name is
    # interpolated into the SQL; this pattern whitelists what is allowed.
    _extract_format_re = _lazy_re_compile(r"[A-Z_]+")

    def date_extract_sql(self, lookup_type, sql, params):
        """
        Return ``(sql, params)`` extracting ``lookup_type`` from ``sql``.

        Raise ValueError if a lookup type that falls through to EXTRACT is
        not made only of letters and underscores. The unit name is
        interpolated into the SQL rather than bound as a parameter.
        """
        # https://dev.mysql.com/doc/mysql/en/date-and-time-functions.html
        if lookup_type == "week_day":
            # DAYOFWEEK() returns an integer, 1-7, Sunday=1.
            return f"DAYOFWEEK({sql})", params
        elif lookup_type == "iso_week_day":
            # WEEKDAY() returns an integer, 0-6, Monday=0.
            return f"WEEKDAY({sql}) + 1", params
        elif lookup_type == "week":
            # Override the value of default_week_format for consistency with
            # other database backends.
            # Mode 3: Monday, 1-53, with 4 or more days this year.
            return f"WEEK({sql}, 3)", params
        elif lookup_type == "iso_year":
            # Get the year part from the YEARWEEK function, which returns a
            # number as year * 100 + week.
            return f"TRUNCATE(YEARWEEK({sql}, 3), -2) / 100", params
        else:
            # EXTRACT returns 1-53 based on ISO-8601 for the week number.
            lookup_type = lookup_type.upper()
            if not self._extract_format_re.fullmatch(lookup_type):
                raise ValueError(f"Invalid lookup type: {lookup_type!r}")
            return f"EXTRACT({lookup_type} FROM {sql})", params

    def date_trunc_sql(self, lookup_type, sql, params, tzname=None):
        """
        Return ``(sql, params)`` truncating ``sql`` to a DATE.

        The value is first converted to ``tzname`` when needed. Unknown
        lookup types fall back to ``DATE(sql)``.
        """
        sql, params = self._convert_sql_to_tz(sql, params, tzname)
        fields = {
            "year": "%Y-01-01",
            "month": "%Y-%m-01",
        }
        if lookup_type in fields:
            format_str = fields[lookup_type]
            return f"CAST(DATE_FORMAT({sql}, %s) AS DATE)", (*params, format_str)
        elif lookup_type == "quarter":
            # ``sql`` appears twice, so its params are repeated.
            return (
                f"MAKEDATE(YEAR({sql}), 1) + "
                f"INTERVAL QUARTER({sql}) QUARTER - INTERVAL 1 QUARTER",
                (*params, *params),
            )
        elif lookup_type == "week":
            # Step back to the Monday of the week (WEEKDAY() is 0 on Monday).
            return f"DATE_SUB({sql}, INTERVAL WEEKDAY({sql}) DAY)", (*params, *params)
        else:
            return f"DATE({sql})", params

    def _prepare_tzname_delta(self, tzname):
        """
        Return the timezone argument to pass to CONVERT_TZ.

        If ``split_tzname_delta`` finds an offset in ``tzname``, return
        ``sign + offset``. Otherwise return the zone name it produced.
        """
        tzname, sign, offset = split_tzname_delta(tzname)
        return f"{sign}{offset}" if offset else tzname

    def _convert_sql_to_tz(self, sql, params, tzname):
        """
        Wrap ``sql`` in CONVERT_TZ from the connection timezone to ``tzname``.

        ``(sql, params)`` is returned unchanged when ``tzname`` is falsy or
        equals the connection's timezone name.
        """
        if tzname and self.connection.timezone_name != tzname:
            return f"CONVERT_TZ({sql}, %s, %s)", (
                *params,
                self.connection.timezone_name,
                self._prepare_tzname_delta(tzname),
            )
        return sql, params

    def datetime_cast_date_sql(self, sql, params, tzname):
        """Return ``(sql, params)`` casting a datetime to DATE in ``tzname``."""
        sql, params = self._convert_sql_to_tz(sql, params, tzname)
        return f"DATE({sql})", params

    def datetime_cast_time_sql(self, sql, params, tzname):
        """Return ``(sql, params)`` casting a datetime to TIME in ``tzname``."""
        sql, params = self._convert_sql_to_tz(sql, params, tzname)
        return f"TIME({sql})", params

    def datetime_extract_sql(self, lookup_type, sql, params, tzname):
        """
        Return ``(sql, params)`` extracting ``lookup_type`` from a datetime.

        The value is first converted to ``tzname``, then extraction is
        delegated to ``date_extract_sql``.
        """
        sql, params = self._convert_sql_to_tz(sql, params, tzname)
        return self.date_extract_sql(lookup_type, sql, params)

    def datetime_trunc_sql(self, lookup_type, sql, params, tzname):
        """
        Return ``(sql, params)`` truncating a datetime to ``lookup_type``.

        "quarter" and "week" get dedicated SQL. For year through second, the
        value is reformatted with DATE_FORMAT and cast back to DATETIME.
        Unknown lookup types return the (timezone-converted) expression
        unchanged.
        """
        sql, params = self._convert_sql_to_tz(sql, params, tzname)
        fields = ["year", "month", "day", "hour", "minute", "second"]
        # Each field's DATE_FORMAT piece, and the constant that replaces it
        # once past the truncation point. Named to avoid shadowing the
        # ``format`` builtin.
        format_parts = ("%Y-", "%m", "-%d", " %H:", "%i", ":%s")
        format_defaults = ("0000-", "01", "-01", " 00:", "00", ":00")
        if lookup_type == "quarter":
            return (
                f"CAST(DATE_FORMAT(MAKEDATE(YEAR({sql}), 1) + "
                f"INTERVAL QUARTER({sql}) QUARTER - "
                f"INTERVAL 1 QUARTER, %s) AS DATETIME)"
            ), (*params, *params, "%Y-%m-01 00:00:00")
        if lookup_type == "week":
            return (
                f"CAST(DATE_FORMAT("
                f"DATE_SUB({sql}, INTERVAL WEEKDAY({sql}) DAY), %s) AS DATETIME)"
            ), (*params, *params, "%Y-%m-%d 00:00:00")
        try:
            i = fields.index(lookup_type) + 1
        except ValueError:
            pass
        else:
            # Keep the real components up to and including the lookup field;
            # replace the rest with their zero/one defaults.
            format_str = "".join(format_parts[:i] + format_defaults[i:])
            return f"CAST(DATE_FORMAT({sql}, %s) AS DATETIME)", (*params, format_str)
        return sql, params

    def time_trunc_sql(self, lookup_type, sql, params, tzname=None):
        """
        Return ``(sql, params)`` truncating a time value to hour/minute/second.

        Unknown lookup types fall back to ``TIME(sql)``.
        """
        sql, params = self._convert_sql_to_tz(sql, params, tzname)
        fields = {
            "hour": "%H:00:00",
            "minute": "%H:%i:00",
            "second": "%H:%i:%s",
        }
        if lookup_type in fields:
            format_str = fields[lookup_type]
            return f"CAST(DATE_FORMAT({sql}, %s) AS TIME)", (*params, format_str)
        else:
            return f"TIME({sql})", params

    def fetch_returned_insert_rows(self, cursor):
        """
        Given a cursor object that has just performed an INSERT...RETURNING
        statement into a table, return the tuple of returned data.
        """
        return cursor.fetchall()

    def format_for_duration_arithmetic(self, sql):
        """Wrap a microsecond count expression in a MySQL INTERVAL."""
        return f"INTERVAL {sql} MICROSECOND"

    def force_no_ordering(self):
        """
        "ORDER BY NULL" prevents MySQL from implicitly ordering by grouped
        columns. If no ordering would otherwise be applied, we don't want any
        implicit sorting going on.
        """
        return [(None, ("NULL", [], False))]

    def adapt_decimalfield_value(self, value, max_digits=None, decimal_places=None):
        """Return ``value`` unchanged; no MySQL-specific adaptation is done."""
        return value

    def last_executed_query(self, cursor, sql, params):
        """Return the query text the driver recorded for ``cursor``."""
        # With MySQLdb, cursor objects have an (undocumented) "_executed"
        # attribute where the exact query sent to the database is saved.
        # See MySQLdb/cursors.py in the source distribution.
        # MySQLdb returns string, PyMySQL bytes.
        return force_str(getattr(cursor, "_executed", None), errors="replace")

    def no_limit_value(self):
        """Return the LIMIT value used to mean "no limit"."""
        # 2**64 - 1, as recommended by the MySQL documentation
        return 18446744073709551615

    def quote_name(self, name):
        """Wrap ``name`` in backticks unless it is already wrapped."""
        if name.startswith("`") and name.endswith("`"):
            return name  # Quoting once is enough.
        return f"`{name}`"

    def return_insert_columns(self, fields):
        """
        Return ``(sql, params)`` for a RETURNING clause listing ``fields``.

        Each column is qualified by its model's table. An empty ``fields``
        yields an empty clause.
        """
        # MySQL and MariaDB < 10.5.0 don't support an INSERT...RETURNING
        # statement.
        # NOTE(review): callers presumably gate this on a connection feature
        # flag; confirm before relying on it for MySQL.
        if not fields:
            return "", ()
        columns = [
            f"{self.quote_name(field.model._meta.db_table)}.{self.quote_name(field.column)}"
            for field in fields
        ]
        return "RETURNING {}".format(", ".join(columns)), ()

    def validate_autopk_value(self, value):
        """
        Return ``value`` if it is acceptable for an AUTO_INCREMENT column.

        Raise ValueError for 0 unless the connection allows it.
        """
        # Zero in AUTO_INCREMENT field does not work without the
        # NO_AUTO_VALUE_ON_ZERO SQL mode.
        if value == 0 and not self.connection.features.allows_auto_pk_0:
            raise ValueError(
                "The database backend does not accept 0 as a value for AutoField."
            )
        return value

    def adapt_datetimefield_value(self, value):
        """
        Adapt a datetime for storage.

        Returns None for None and passes expressions through unchanged.
        Aware datetimes are made naive in the connection timezone. The
        result is the ``str()`` of the value.
        """
        if value is None:
            return None

        # Expression values are adapted by the database.
        if hasattr(value, "resolve_expression"):
            return value

        # MySQL doesn't support tz-aware datetimes
        if timezone.is_aware(value):
            value = timezone.make_naive(value, self.connection.timezone)
        return str(value)

    def adapt_timefield_value(self, value):
        """
        Adapt a time for storage as an ISO string with microseconds.

        Returns None for None and passes expressions through unchanged.
        Raise ValueError for timezone-aware times.
        """
        if value is None:
            return None

        # Expression values are adapted by the database.
        if hasattr(value, "resolve_expression"):
            return value

        # MySQL doesn't support tz-aware times
        if timezone.is_aware(value):
            raise ValueError("MySQL backend does not support timezone-aware times.")

        return value.isoformat(timespec="microseconds")

    def max_name_length(self):
        """Return the maximum identifier length (64)."""
        return 64

    def pk_default_value(self):
        """Return the SQL literal used as the default primary key value."""
        return "NULL"

    def bulk_insert_sql(self, fields, placeholder_rows):
        """
        Build a multi-row ``VALUES (...), (...)`` clause.

        ``placeholder_rows`` holds one iterable of placeholder strings per
        row.
        """
        placeholder_rows_sql = (", ".join(row) for row in placeholder_rows)
        values_sql = ", ".join(f"({sql})" for sql in placeholder_rows_sql)
        return "VALUES " + values_sql

    def combine_expression(self, connector, sub_expressions):
        """
        Combine ``sub_expressions`` with ``connector`` using MySQL syntax.

        ``^`` maps to POW() and ``#`` (XOR) maps to MySQL's ``^``. ``>>`` is
        emulated with FLOOR(lhs / POW(2, rhs)). Anything else is delegated
        to the base class.
        """
        if connector == "^":
            return "POW({})".format(",".join(sub_expressions))
        # Convert the result to a signed integer since MySQL's binary operators
        # return an unsigned integer.
        elif connector in ("&", "|", "<<", "#"):
            connector = "^" if connector == "#" else connector
            return f"CONVERT({connector.join(sub_expressions)}, SIGNED)"
        elif connector == ">>":
            lhs, rhs = sub_expressions
            return f"FLOOR({lhs} / POW(2, {rhs}))"
        return super().combine_expression(connector, sub_expressions)

    def get_db_converters(self, expression):
        """Extend base converters with boolean, datetime and UUID handling."""
        converters = super().get_db_converters(expression)
        internal_type = expression.output_field.get_internal_type()
        if internal_type == "BooleanField":
            converters.append(self.convert_booleanfield_value)
        elif internal_type == "DateTimeField":
            converters.append(self.convert_datetimefield_value)
        elif internal_type == "UUIDField":
            converters.append(self.convert_uuidfield_value)
        return converters

    def convert_booleanfield_value(self, value, expression, connection):
        """Turn 0/1 into a bool; return any other value unchanged."""
        if value in (0, 1):
            value = bool(value)
        return value

    def convert_datetimefield_value(self, value, expression, connection):
        """Make a non-None datetime aware in the connection timezone."""
        if value is not None:
            value = timezone.make_aware(value, self.connection.timezone)
        return value

    def convert_uuidfield_value(self, value, expression, connection):
        """Parse a non-None stored value into ``uuid.UUID``."""
        # NOTE(review): assumes the driver returns the UUID as a str.
        if value is not None:
            value = uuid.UUID(value)
        return value

    def binary_placeholder_sql(self, value):
        """
        Return the placeholder for a binary value.

        Literal values use the ``_binary`` introducer. None and expressions
        (anything with ``as_sql``) use a plain ``%s``.
        """
        return (
            "_binary %s" if value is not None and not hasattr(value, "as_sql") else "%s"
        )

    def subtract_temporals(self, internal_type, lhs, rhs):
        """
        Return ``(sql, params)`` computing ``lhs - rhs`` in microseconds.

        ``lhs`` and ``rhs`` are ``(sql, params)`` pairs.
        """
        lhs_sql, lhs_params = lhs
        rhs_sql, rhs_params = rhs
        if internal_type == "TimeField":
            if self.connection.mysql_is_mariadb:
                # MariaDB includes the microsecond component in TIME_TO_SEC as
                # a decimal. MySQL returns an integer without microseconds.
                return (
                    f"CAST((TIME_TO_SEC({lhs_sql}) - TIME_TO_SEC({rhs_sql})) "
                    "* 1000000 AS SIGNED)"
                ), (
                    *lhs_params,
                    *rhs_params,
                )
            # Each operand appears twice (TIME_TO_SEC and MICROSECOND), so
            # its params are repeated in that order.
            return (
                f"((TIME_TO_SEC({lhs_sql}) * 1000000 + MICROSECOND({lhs_sql})) -"
                f" (TIME_TO_SEC({rhs_sql}) * 1000000 + MICROSECOND({rhs_sql})))"
            ), tuple(lhs_params) * 2 + tuple(rhs_params) * 2
        # TIMESTAMPDIFF takes the earlier value first, hence rhs before lhs.
        params = (*rhs_params, *lhs_params)
        return f"TIMESTAMPDIFF(MICROSECOND, {rhs_sql}, {lhs_sql})", params

    def explain_query_prefix(self, format=None, **options):
        """
        Return the EXPLAIN prefix, handling ANALYZE and FORMAT for MySQL and
        MariaDB.
        """
        # Alias MySQL's TRADITIONAL to TEXT for consistency with other backends.
        if format and format.upper() == "TEXT":
            format = "TRADITIONAL"
        elif (
            not format and "TREE" in self.connection.features.supported_explain_formats
        ):
            # Use TREE by default (if supported) as it's more informative.
            format = "TREE"
        analyze = options.pop("analyze", False)
        prefix = super().explain_query_prefix(format, **options)
        if analyze and self.connection.features.supports_explain_analyze:
            # MariaDB uses ANALYZE instead of EXPLAIN ANALYZE.
            prefix = (
                "ANALYZE" if self.connection.mysql_is_mariadb else prefix + " ANALYZE"
            )
        if format and not (analyze and not self.connection.mysql_is_mariadb):
            # Only MariaDB supports the analyze option with formats.
            prefix += f" FORMAT={format}"
        return prefix

    def regex_lookup(self, lookup_type):
        """
        Return the SQL template for a regex lookup.

        "regex" is case-sensitive; any other lookup type is treated as the
        case-insensitive variant.
        """
        # REGEXP_LIKE doesn't exist in MariaDB.
        if self.connection.mysql_is_mariadb:
            if lookup_type == "regex":
                return "%s REGEXP BINARY %s"
            return "%s REGEXP %s"

        match_option = "c" if lookup_type == "regex" else "i"
        return f"REGEXP_LIKE(%s, %s, '{match_option}')"

    def insert_statement(self, on_conflict=None):
        """Use ``INSERT IGNORE INTO`` for OnConflict.IGNORE, else the default."""
        if on_conflict == OnConflict.IGNORE:
            return "INSERT IGNORE INTO"
        return super().insert_statement(on_conflict=on_conflict)

    def lookup_cast(self, lookup_type, internal_type=None):
        """
        Return the placeholder template for the left-hand side of a lookup.

        JSONField values are wrapped in JSON_UNQUOTE on MariaDB for every
        lookup. On MySQL this happens only for the listed text-style
        lookups.
        """
        lookup = "%s"
        if internal_type == "JSONField":
            if self.connection.mysql_is_mariadb or lookup_type in (
                "iexact",
                "contains",
                "icontains",
                "startswith",
                "istartswith",
                "endswith",
                "iendswith",
                "regex",
                "iregex",
            ):
                lookup = "JSON_UNQUOTE(%s)"
        return lookup

    def conditional_expression_supported_in_where_clause(self, expression):
        """
        Return whether ``expression`` may be used bare in a WHERE clause.
        """
        # MySQL ignores indexes with boolean fields unless they're compared
        # directly to a boolean value.
        if isinstance(expression, Exists | Lookup):
            return True
        if isinstance(expression, ExpressionWrapper) and expression.conditional:
            return self.conditional_expression_supported_in_where_clause(
                expression.expression
            )
        if getattr(expression, "conditional", False):
            return False
        return super().conditional_expression_supported_in_where_clause(expression)

    def on_conflict_suffix_sql(self, fields, on_conflict, update_fields, unique_fields):
        """
        Return the upsert suffix for a bulk INSERT.

        For OnConflict.UPDATE, build ``ON DUPLICATE KEY UPDATE`` assigning
        each of ``update_fields`` from the incoming row. Other modes are
        delegated to the base class.
        """
        if on_conflict == OnConflict.UPDATE:
            conflict_suffix_sql = "ON DUPLICATE KEY UPDATE %(fields)s"
            # The use of VALUES() is deprecated in MySQL 8.0.20+. Instead, use
            # aliases for the new row and its columns available in MySQL
            # 8.0.19+.
            if not self.connection.mysql_is_mariadb:
                if self.connection.mysql_version >= (8, 0, 19):
                    conflict_suffix_sql = f"AS new {conflict_suffix_sql}"
                    field_sql = "%(field)s = new.%(field)s"
                else:
                    field_sql = "%(field)s = VALUES(%(field)s)"
            else:
                # Use VALUE() on MariaDB.
                field_sql = "%(field)s = VALUE(%(field)s)"

            fields = ", ".join(
                [
                    field_sql % {"field": field}
                    for field in map(self.quote_name, update_fields)
                ]
            )
            return conflict_suffix_sql % {"fields": fields}
        return super().on_conflict_suffix_sql(
            fields,
            on_conflict,
            update_fields,
            unique_fields,
        )