1import uuid
2
3from plain.models.backends.base.operations import BaseDatabaseOperations
4from plain.models.backends.utils import split_tzname_delta
5from plain.models.constants import OnConflict
6from plain.models.expressions import Exists, ExpressionWrapper
7from plain.models.lookups import Lookup
8from plain.utils import timezone
9from plain.utils.encoding import force_str
10from plain.utils.regex_helper import _lazy_re_compile
11
12
class DatabaseOperations(BaseDatabaseOperations):
    """SQL generation and value-adaptation hooks for the MySQL/MariaDB backend.

    Overrides ``BaseDatabaseOperations`` wherever MySQL syntax or type
    handling differs. Where MySQL and MariaDB diverge, methods branch on
    ``self.connection.mysql_is_mariadb`` and ``self.connection.mysql_version``.
    """

    compiler_module = "plain.models.backends.mysql.compiler"

    # MySQL stores positive fields as UNSIGNED ints.
    integer_field_ranges = {
        **BaseDatabaseOperations.integer_field_ranges,
        "PositiveSmallIntegerField": (0, 65535),
        "PositiveIntegerField": (0, 4294967295),
        "PositiveBigIntegerField": (0, 18446744073709551615),
    }
    # MySQL type names per internal field type; presumably used as the
    # target of CAST(... AS <type>) by the base class / Cast expression.
    cast_data_types = {
        "PrimaryKeyField": "signed integer",
        "CharField": "char(%(max_length)s)",
        "DecimalField": "decimal(%(max_digits)s, %(decimal_places)s)",
        "TextField": "char",
        "IntegerField": "signed integer",
        "BigIntegerField": "signed integer",
        "SmallIntegerField": "signed integer",
        "PositiveBigIntegerField": "unsigned integer",
        "PositiveIntegerField": "unsigned integer",
        "PositiveSmallIntegerField": "unsigned integer",
        "DurationField": "signed integer",
    }
    cast_char_field_without_max_length = "char"
    explain_prefix = "EXPLAIN"

    # EXTRACT format cannot be passed in parameters.
    _extract_format_re = _lazy_re_compile(r"[A-Z_]+")

    def date_extract_sql(self, lookup_type, sql, params):
        """Return ``(sql, params)`` extracting ``lookup_type`` from ``sql``.

        ``week_day``, ``iso_week_day``, ``week`` and ``iso_year`` map to
        dedicated MySQL functions. Any other lookup type is upper-cased and
        rendered as ``EXTRACT(<UNIT> FROM sql)``.

        Raises:
            ValueError: if the upper-cased lookup type contains characters
                other than ``A-Z`` and ``_``.
        """
        # https://dev.mysql.com/doc/mysql/en/date-and-time-functions.html
        if lookup_type == "week_day":
            # DAYOFWEEK() returns an integer, 1-7, Sunday=1.
            return f"DAYOFWEEK({sql})", params
        if lookup_type == "iso_week_day":
            # WEEKDAY() returns an integer, 0-6, Monday=0.
            return f"WEEKDAY({sql}) + 1", params
        if lookup_type == "week":
            # Override the value of default_week_format for consistency with
            # other database backends.
            # Mode 3: Monday, 1-53, with 4 or more days this year.
            return f"WEEK({sql}, 3)", params
        if lookup_type == "iso_year":
            # Get the year part from the YEARWEEK function, which returns a
            # number as year * 100 + week.
            return f"TRUNCATE(YEARWEEK({sql}, 3), -2) / 100", params
        # The EXTRACT unit is interpolated directly into the SQL (it cannot
        # be a query parameter), so only allow a bare upper-case keyword.
        unit = lookup_type.upper()
        if not self._extract_format_re.fullmatch(unit):
            raise ValueError(f"Invalid lookup type: {unit!r}")
        return f"EXTRACT({unit} FROM {sql})", params

    def date_trunc_sql(self, lookup_type, sql, params, tzname=None):
        """Return ``(sql, params)`` truncating ``sql`` to a ``lookup_type`` DATE.

        The expression is first converted to ``tzname`` when it differs from
        the connection timezone. Unrecognized lookup types fall back to
        ``DATE(sql)``.
        """
        sql, params = self._convert_sql_to_tz(sql, params, tzname)
        formats = {
            "year": "%Y-01-01",
            "month": "%Y-%m-01",
        }
        if lookup_type in formats:
            # The DATE_FORMAT pattern is passed as a query parameter.
            return f"CAST(DATE_FORMAT({sql}, %s) AS DATE)", (
                *params,
                formats[lookup_type],
            )
        if lookup_type == "quarter":
            # Jan 1st of the year plus (QUARTER - 1) quarters. ``sql`` appears
            # twice, so its params are repeated.
            return (
                f"MAKEDATE(YEAR({sql}), 1) + "
                f"INTERVAL QUARTER({sql}) QUARTER - INTERVAL 1 QUARTER",
                (*params, *params),
            )
        if lookup_type == "week":
            # WEEKDAY() is 0 for Monday, so this steps back to Monday.
            return f"DATE_SUB({sql}, INTERVAL WEEKDAY({sql}) DAY)", (*params, *params)
        return f"DATE({sql})", params

    def _prepare_tzname_delta(self, tzname):
        """Return the UTC offset part of ``tzname`` (e.g. ``+05:00``) if it
        has one, otherwise the bare timezone name, for use in CONVERT_TZ()."""
        tzname, sign, offset = split_tzname_delta(tzname)
        return f"{sign}{offset}" if offset else tzname

    def _convert_sql_to_tz(self, sql, params, tzname):
        """Wrap ``sql`` in CONVERT_TZ() from the connection timezone to ``tzname``.

        Returns ``(sql, params)`` unchanged when ``tzname`` is falsy or equals
        ``self.connection.timezone_name``. Otherwise it appends the source and
        target timezone names to the params.
        """
        if tzname and self.connection.timezone_name != tzname:
            return f"CONVERT_TZ({sql}, %s, %s)", (
                *params,
                self.connection.timezone_name,
                self._prepare_tzname_delta(tzname),
            )
        return sql, params

    def datetime_cast_date_sql(self, sql, params, tzname):
        """Return ``(sql, params)`` casting a datetime to DATE in ``tzname``."""
        sql, params = self._convert_sql_to_tz(sql, params, tzname)
        return f"DATE({sql})", params

    def datetime_cast_time_sql(self, sql, params, tzname):
        """Return ``(sql, params)`` casting a datetime to TIME in ``tzname``."""
        sql, params = self._convert_sql_to_tz(sql, params, tzname)
        return f"TIME({sql})", params

    def datetime_extract_sql(self, lookup_type, sql, params, tzname):
        """Convert ``sql`` to ``tzname``, then delegate to ``date_extract_sql``."""
        sql, params = self._convert_sql_to_tz(sql, params, tzname)
        return self.date_extract_sql(lookup_type, sql, params)

    def datetime_trunc_sql(self, lookup_type, sql, params, tzname):
        """Return ``(sql, params)`` truncating a datetime to ``lookup_type``.

        ``quarter`` and ``week`` use date arithmetic. ``year`` through
        ``second`` use a DATE_FORMAT pattern. Any other lookup type returns
        the (timezone-converted) expression untruncated.
        """
        sql, params = self._convert_sql_to_tz(sql, params, tzname)
        if lookup_type == "quarter":
            return (
                f"CAST(DATE_FORMAT(MAKEDATE(YEAR({sql}), 1) + "
                f"INTERVAL QUARTER({sql}) QUARTER - "
                "INTERVAL 1 QUARTER, %s) AS DATETIME)"
            ), (*params, *params, "%Y-%m-01 00:00:00")
        if lookup_type == "week":
            return (
                f"CAST(DATE_FORMAT("
                f"DATE_SUB({sql}, INTERVAL WEEKDAY({sql}) DAY), %s) AS DATETIME)"
            ), (*params, *params, "%Y-%m-%d 00:00:00")
        fields = ("year", "month", "day", "hour", "minute", "second")
        if lookup_type not in fields:
            return sql, params
        # Keep the DATE_FORMAT specifiers up to and including ``lookup_type``
        # and pad the finer-grained components with their minimum values.
        # (In MySQL's DATE_FORMAT, %i is minutes and %s is seconds.)
        format_parts = ("%Y-", "%m", "-%d", " %H:", "%i", ":%s")
        default_parts = ("0000-", "01", "-01", " 00:", "00", ":00")
        cut = fields.index(lookup_type) + 1
        format_str = "".join(format_parts[:cut] + default_parts[cut:])
        return f"CAST(DATE_FORMAT({sql}, %s) AS DATETIME)", (*params, format_str)

    def time_trunc_sql(self, lookup_type, sql, params, tzname=None):
        """Return ``(sql, params)`` truncating a time to ``lookup_type``.

        Supports ``hour``, ``minute`` and ``second``. Other lookup types fall
        back to ``TIME(sql)``.
        """
        sql, params = self._convert_sql_to_tz(sql, params, tzname)
        formats = {
            "hour": "%H:00:00",
            "minute": "%H:%i:00",
            "second": "%H:%i:%s",
        }
        if lookup_type in formats:
            return f"CAST(DATE_FORMAT({sql}, %s) AS TIME)", (
                *params,
                formats[lookup_type],
            )
        return f"TIME({sql})", params

    def fetch_returned_insert_rows(self, cursor):
        """
        Given a cursor object that has just performed an INSERT...RETURNING
        statement into a table, return the tuple of returned data.
        """
        return cursor.fetchall()

    def format_for_duration_arithmetic(self, sql):
        """Express a microsecond count ``sql`` as a MySQL INTERVAL."""
        return f"INTERVAL {sql} MICROSECOND"

    def force_no_ordering(self):
        """
        "ORDER BY NULL" prevents MySQL from implicitly ordering by grouped
        columns. If no ordering would otherwise be applied, we don't want any
        implicit sorting going on.
        """
        return [(None, ("NULL", [], False))]

    def adapt_decimalfield_value(self, value, max_digits=None, decimal_places=None):
        """Return ``value`` unchanged; no Decimal adaptation is done here."""
        return value

    def last_executed_query(self, cursor, sql, params):
        """Return the exact query text recorded by the driver's cursor."""
        # With MySQLdb, cursor objects have an (undocumented) "_executed"
        # attribute where the exact query sent to the database is saved.
        # See MySQLdb/cursors.py in the source distribution.
        # MySQLdb returns string, PyMySQL bytes.
        # NOTE(review): if "_executed" is missing, force_str(None) presumably
        # yields the string "None" rather than None — confirm callers are OK.
        return force_str(getattr(cursor, "_executed", None), errors="replace")

    def no_limit_value(self):
        """Return the LIMIT used when only an OFFSET is requested."""
        # 2**64 - 1, as recommended by the MySQL documentation
        return 18446744073709551615

    def quote_name(self, name):
        """Wrap ``name`` in backticks unless it is already wrapped.

        NOTE(review): backticks inside ``name`` are not escaped; identifiers
        are assumed to come from trusted model metadata.
        """
        if name.startswith("`") and name.endswith("`"):
            return name  # Quoting once is enough.
        return f"`{name}`"

    def return_insert_columns(self, fields):
        """Return ``(sql, params)`` for a ``RETURNING`` clause over ``fields``.

        Returns ``("", ())`` when ``fields`` is empty.
        """
        # MySQL and MariaDB < 10.5.0 don't support an INSERT...RETURNING
        # statement.
        if not fields:
            return "", ()
        columns = [
            f"{self.quote_name(field.model._meta.db_table)}.{self.quote_name(field.column)}"
            for field in fields
        ]
        return f"RETURNING {', '.join(columns)}", ()

    def validate_autopk_value(self, value):
        """Return ``value``, rejecting 0 unless the backend allows it.

        Raises:
            ValueError: if ``value`` is 0 and
                ``connection.features.allows_auto_pk_0`` is false.
        """
        # Zero in AUTO_INCREMENT field does not work without the
        # NO_AUTO_VALUE_ON_ZERO SQL mode.
        if value == 0 and not self.connection.features.allows_auto_pk_0:
            raise ValueError(
                "The database backend does not accept 0 as a value for PrimaryKeyField."
            )
        return value

    def adapt_datetimefield_value(self, value):
        """Prepare a datetime for MySQL as a naive string in the connection tz.

        ``None`` and expressions are returned unchanged. Aware datetimes are
        converted to naive ones in ``self.connection.timezone`` first.
        """
        if value is None:
            return None

        # Expression values are adapted by the database.
        if hasattr(value, "resolve_expression"):
            return value

        # MySQL doesn't support tz-aware datetimes
        if timezone.is_aware(value):
            value = timezone.make_naive(value, self.connection.timezone)
        return str(value)

    def adapt_timefield_value(self, value):
        """Prepare a time for MySQL as an ISO string with microseconds.

        Raises:
            ValueError: if ``value`` is timezone-aware.
        """
        if value is None:
            return None

        # Expression values are adapted by the database.
        if hasattr(value, "resolve_expression"):
            return value

        # MySQL doesn't support tz-aware times
        if timezone.is_aware(value):
            raise ValueError("MySQL backend does not support timezone-aware times.")

        return value.isoformat(timespec="microseconds")

    def max_name_length(self):
        """Return the maximum identifier length (MySQL's limit is 64)."""
        return 64

    def pk_default_value(self):
        """Return the value inserted for an auto primary key.

        Inserting NULL into an AUTO_INCREMENT column makes MySQL generate the
        next value.
        """
        return "NULL"

    def bulk_insert_sql(self, fields, placeholder_rows):
        """Render ``placeholder_rows`` as ``VALUES (a, b), (c, d), ...``."""
        rows_sql = ", ".join(f"({', '.join(row)})" for row in placeholder_rows)
        return "VALUES " + rows_sql

    def combine_expression(self, connector, sub_expressions):
        """Combine ``sub_expressions`` with ``connector`` using MySQL syntax.

        ``^`` is power, ``#`` is bitwise XOR, and ``>>`` is emulated with
        floor division. Other connectors use the base implementation.
        """
        if connector == "^":
            return f"POW({','.join(sub_expressions)})"
        # Convert the result to a signed integer since MySQL's binary operators
        # return an unsigned integer.
        if connector in ("&", "|", "<<", "#"):
            operator = "^" if connector == "#" else connector
            return f"CONVERT({operator.join(sub_expressions)}, SIGNED)"
        if connector == ">>":
            lhs, rhs = sub_expressions
            return f"FLOOR({lhs} / POW(2, {rhs}))"
        return super().combine_expression(connector, sub_expressions)

    def get_db_converters(self, expression):
        """Add MySQL-specific converters for boolean, datetime and UUID fields."""
        converters = super().get_db_converters(expression)
        internal_type = expression.output_field.get_internal_type()
        if internal_type == "BooleanField":
            converters.append(self.convert_booleanfield_value)
        elif internal_type == "DateTimeField":
            converters.append(self.convert_datetimefield_value)
        elif internal_type == "UUIDField":
            converters.append(self.convert_uuidfield_value)
        return converters

    def convert_booleanfield_value(self, value, expression, connection):
        """Turn database 0/1 into ``bool``; other values pass through."""
        if value in (0, 1):
            value = bool(value)
        return value

    def convert_datetimefield_value(self, value, expression, connection):
        """Make a non-None datetime from the database aware in the connection tz."""
        if value is not None:
            value = timezone.make_aware(value, self.connection.timezone)
        return value

    def convert_uuidfield_value(self, value, expression, connection):
        """Parse a non-None database value into ``uuid.UUID``."""
        if value is not None:
            value = uuid.UUID(value)
        return value

    def binary_placeholder_sql(self, value):
        """Return the placeholder for a binary value.

        Non-None literal values get MySQL's ``_binary`` introducer. None and
        expressions (objects with ``as_sql``) use a plain ``%s``.
        """
        return (
            "_binary %s" if value is not None and not hasattr(value, "as_sql") else "%s"
        )

    def subtract_temporals(self, internal_type, lhs, rhs):
        """Return ``(sql, params)`` for ``lhs - rhs`` in microseconds.

        ``lhs`` and ``rhs`` are ``(sql, params)`` pairs.
        """
        lhs_sql, lhs_params = lhs
        rhs_sql, rhs_params = rhs
        if internal_type == "TimeField":
            if self.connection.mysql_is_mariadb:
                # MariaDB includes the microsecond component in TIME_TO_SEC as
                # a decimal. MySQL returns an integer without microseconds.
                return (
                    f"CAST((TIME_TO_SEC({lhs_sql}) - TIME_TO_SEC({rhs_sql})) "
                    "* 1000000 AS SIGNED)"
                ), (*lhs_params, *rhs_params)
            # Each side appears twice in the SQL, so repeat its params.
            return (
                f"((TIME_TO_SEC({lhs_sql}) * 1000000 + MICROSECOND({lhs_sql})) -"
                f" (TIME_TO_SEC({rhs_sql}) * 1000000 + MICROSECOND({rhs_sql})))"
            ), (*lhs_params, *lhs_params, *rhs_params, *rhs_params)
        # TIMESTAMPDIFF(unit, start, end) computes end - start.
        return f"TIMESTAMPDIFF(MICROSECOND, {rhs_sql}, {lhs_sql})", (
            *rhs_params,
            *lhs_params,
        )

    def explain_query_prefix(self, format=None, **options):
        """Build the EXPLAIN prefix, handling MySQL/MariaDB format differences."""
        # Alias MySQL's TRADITIONAL to TEXT for consistency with other backends.
        if format and format.upper() == "TEXT":
            format = "TRADITIONAL"
        elif (
            not format and "TREE" in self.connection.features.supported_explain_formats
        ):
            # Use TREE by default (if supported) as it's more informative.
            format = "TREE"
        analyze = options.pop("analyze", False)
        prefix = super().explain_query_prefix(format, **options)
        if analyze and self.connection.features.supports_explain_analyze:
            # MariaDB uses ANALYZE instead of EXPLAIN ANALYZE.
            prefix = (
                "ANALYZE" if self.connection.mysql_is_mariadb else prefix + " ANALYZE"
            )
        if format and not (analyze and not self.connection.mysql_is_mariadb):
            # Only MariaDB supports the analyze option with formats.
            prefix += f" FORMAT={format}"
        return prefix

    def regex_lookup(self, lookup_type):
        """Return the SQL template for ``regex`` (case-sensitive) or ``iregex``."""
        # REGEXP_LIKE doesn't exist in MariaDB.
        if self.connection.mysql_is_mariadb:
            if lookup_type == "regex":
                return "%s REGEXP BINARY %s"
            return "%s REGEXP %s"

        match_option = "c" if lookup_type == "regex" else "i"
        return f"REGEXP_LIKE(%s, %s, '{match_option}')"

    def insert_statement(self, on_conflict=None):
        """Use ``INSERT IGNORE`` when conflicts should be ignored."""
        if on_conflict == OnConflict.IGNORE:
            return "INSERT IGNORE INTO"
        return super().insert_statement(on_conflict=on_conflict)

    def lookup_cast(self, lookup_type, internal_type=None):
        """Return the SQL template applied to the lookup's left-hand side.

        JSONField values are wrapped in JSON_UNQUOTE() for text-style lookups,
        and for every lookup on MariaDB.
        """
        lookup = "%s"
        if internal_type == "JSONField":
            if self.connection.mysql_is_mariadb or lookup_type in (
                "iexact",
                "contains",
                "icontains",
                "startswith",
                "istartswith",
                "endswith",
                "iendswith",
                "regex",
                "iregex",
            ):
                lookup = "JSON_UNQUOTE(%s)"
        return lookup

    def conditional_expression_supported_in_where_clause(self, expression):
        """Report whether ``expression`` can be used bare in a WHERE clause."""
        # MySQL ignores indexes with boolean fields unless they're compared
        # directly to a boolean value.
        if isinstance(expression, Exists | Lookup):
            return True
        if isinstance(expression, ExpressionWrapper) and expression.conditional:
            return self.conditional_expression_supported_in_where_clause(
                expression.expression
            )
        if getattr(expression, "conditional", False):
            return False
        return super().conditional_expression_supported_in_where_clause(expression)

    def on_conflict_suffix_sql(self, fields, on_conflict, update_fields, unique_fields):
        """Return the suffix appended to an INSERT for ``on_conflict`` handling.

        For ``OnConflict.UPDATE`` this is ``ON DUPLICATE KEY UPDATE`` that
        assigns each of ``update_fields`` from the incoming row. Other modes
        are delegated to the base implementation.
        """
        if on_conflict != OnConflict.UPDATE:
            return super().on_conflict_suffix_sql(
                fields,
                on_conflict,
                update_fields,
                unique_fields,
            )

        conflict_suffix_sql = "ON DUPLICATE KEY UPDATE %(fields)s"
        if self.connection.mysql_is_mariadb:
            # Use VALUE() on MariaDB.
            field_sql = "%(field)s = VALUE(%(field)s)"
        elif self.connection.mysql_version >= (8, 0, 19):
            # The use of VALUES() is deprecated in MySQL 8.0.20+. Instead, use
            # aliases for the new row and its columns available in MySQL
            # 8.0.19+.
            conflict_suffix_sql = f"AS new {conflict_suffix_sql}"
            field_sql = "%(field)s = new.%(field)s"
        else:
            field_sql = "%(field)s = VALUES(%(field)s)"

        # Named ``assignments`` so the ``fields`` parameter is not shadowed.
        assignments = ", ".join(
            field_sql % {"field": quoted}
            for quoted in map(self.quote_name, update_fields)
        )
        return conflict_suffix_sql % {"fields": assignments}