1from __future__ import annotations
2
3import datetime
4import uuid
5from collections.abc import Iterable
6from typing import TYPE_CHECKING, Any
7
8from plain.models.backends.base.operations import BaseDatabaseOperations
9from plain.models.backends.utils import CursorWrapper, split_tzname_delta
10from plain.models.constants import OnConflict
11from plain.models.expressions import Exists, ExpressionWrapper, ResolvableExpression
12from plain.models.lookups import Lookup
13from plain.utils import timezone
14from plain.utils.encoding import force_str
15from plain.utils.regex_helper import _lazy_re_compile
16
17if TYPE_CHECKING:
18 from plain.models.backends.base.base import BaseDatabaseWrapper
19 from plain.models.backends.mysql.base import MySQLDatabaseWrapper
20 from plain.models.fields import Field
21
22
23class DatabaseOperations(BaseDatabaseOperations):
    # Type checker hint: connection is always MySQLDatabaseWrapper in this class
    connection: MySQLDatabaseWrapper

    # Module path holding this backend's SQL compiler classes.
    compiler_module = "plain.models.backends.mysql.compiler"

    # MySQL stores positive fields as UNSIGNED ints, so their ranges run from 0
    # to the unsigned maximum for the column width (overriding the base ranges).
    integer_field_ranges = {
        **BaseDatabaseOperations.integer_field_ranges,
        "PositiveSmallIntegerField": (0, 65535),
        "PositiveIntegerField": (0, 4294967295),
        "PositiveBigIntegerField": (0, 18446744073709551615),
    }
    # CAST() target types keyed by field internal type. Every integer-like
    # field (DurationField included) casts to SIGNED or UNSIGNED INTEGER, and
    # TextField casts to CHAR.
    cast_data_types = {
        "PrimaryKeyField": "signed integer",
        "CharField": "char(%(max_length)s)",
        "DecimalField": "decimal(%(max_digits)s, %(decimal_places)s)",
        "TextField": "char",
        "IntegerField": "signed integer",
        "BigIntegerField": "signed integer",
        "SmallIntegerField": "signed integer",
        "PositiveBigIntegerField": "unsigned integer",
        "PositiveIntegerField": "unsigned integer",
        "PositiveSmallIntegerField": "unsigned integer",
        "DurationField": "signed integer",
    }
    # CAST() target used for a CharField that has no max_length.
    cast_char_field_without_max_length = "char"
    # Keyword that starts an EXPLAIN statement on this backend.
    explain_prefix = "EXPLAIN"

    # EXTRACT format cannot be passed in parameters, so date_extract_sql()
    # interpolates it into the SQL and uses this pattern as a whitelist.
    _extract_format_re = _lazy_re_compile(r"[A-Z_]+")
54
55 def date_extract_sql(
56 self, lookup_type: str, sql: str, params: list[Any] | tuple[Any, ...]
57 ) -> tuple[str, list[Any] | tuple[Any, ...]]:
58 # https://dev.mysql.com/doc/mysql/en/date-and-time-functions.html
59 if lookup_type == "week_day":
60 # DAYOFWEEK() returns an integer, 1-7, Sunday=1.
61 return f"DAYOFWEEK({sql})", params
62 elif lookup_type == "iso_week_day":
63 # WEEKDAY() returns an integer, 0-6, Monday=0.
64 return f"WEEKDAY({sql}) + 1", params
65 elif lookup_type == "week":
66 # Override the value of default_week_format for consistency with
67 # other database backends.
68 # Mode 3: Monday, 1-53, with 4 or more days this year.
69 return f"WEEK({sql}, 3)", params
70 elif lookup_type == "iso_year":
71 # Get the year part from the YEARWEEK function, which returns a
72 # number as year * 100 + week.
73 return f"TRUNCATE(YEARWEEK({sql}, 3), -2) / 100", params
74 else:
75 # EXTRACT returns 1-53 based on ISO-8601 for the week number.
76 lookup_type = lookup_type.upper()
77 if not self._extract_format_re.fullmatch(lookup_type):
78 raise ValueError(f"Invalid loookup type: {lookup_type!r}")
79 return f"EXTRACT({lookup_type} FROM {sql})", params
80
81 def date_trunc_sql(
82 self,
83 lookup_type: str,
84 sql: str,
85 params: list[Any] | tuple[Any, ...],
86 tzname: str | None = None,
87 ) -> tuple[str, list[Any] | tuple[Any, ...]]:
88 sql, params = self._convert_sql_to_tz(sql, params, tzname)
89 fields = {
90 "year": "%Y-01-01",
91 "month": "%Y-%m-01",
92 }
93 if lookup_type in fields:
94 format_str = fields[lookup_type]
95 return f"CAST(DATE_FORMAT({sql}, %s) AS DATE)", (*params, format_str)
96 elif lookup_type == "quarter":
97 return (
98 f"MAKEDATE(YEAR({sql}), 1) + "
99 f"INTERVAL QUARTER({sql}) QUARTER - INTERVAL 1 QUARTER",
100 (*params, *params),
101 )
102 elif lookup_type == "week":
103 return f"DATE_SUB({sql}, INTERVAL WEEKDAY({sql}) DAY)", (*params, *params)
104 else:
105 return f"DATE({sql})", params
106
107 def _prepare_tzname_delta(self, tzname: str) -> str:
108 tzname, sign, offset = split_tzname_delta(tzname)
109 return f"{sign}{offset}" if offset else tzname
110
111 def _convert_sql_to_tz(
112 self, sql: str, params: list[Any] | tuple[Any, ...], tzname: str | None
113 ) -> tuple[str, list[Any] | tuple[Any, ...]]:
114 if tzname and self.connection.timezone_name != tzname:
115 return f"CONVERT_TZ({sql}, %s, %s)", (
116 *params,
117 self.connection.timezone_name,
118 self._prepare_tzname_delta(tzname),
119 )
120 return sql, params
121
122 def datetime_cast_date_sql(
123 self, sql: str, params: list[Any] | tuple[Any, ...], tzname: str | None
124 ) -> tuple[str, list[Any] | tuple[Any, ...]]:
125 sql, params = self._convert_sql_to_tz(sql, params, tzname)
126 return f"DATE({sql})", params
127
128 def datetime_cast_time_sql(
129 self, sql: str, params: list[Any] | tuple[Any, ...], tzname: str | None
130 ) -> tuple[str, list[Any] | tuple[Any, ...]]:
131 sql, params = self._convert_sql_to_tz(sql, params, tzname)
132 return f"TIME({sql})", params
133
134 def datetime_extract_sql(
135 self,
136 lookup_type: str,
137 sql: str,
138 params: list[Any] | tuple[Any, ...],
139 tzname: str | None,
140 ) -> tuple[str, list[Any] | tuple[Any, ...]]:
141 sql, params = self._convert_sql_to_tz(sql, params, tzname)
142 return self.date_extract_sql(lookup_type, sql, params)
143
144 def datetime_trunc_sql(
145 self,
146 lookup_type: str,
147 sql: str,
148 params: list[Any] | tuple[Any, ...],
149 tzname: str | None,
150 ) -> tuple[str, list[Any] | tuple[Any, ...]]:
151 sql, params = self._convert_sql_to_tz(sql, params, tzname)
152 fields = ["year", "month", "day", "hour", "minute", "second"]
153 format = ("%Y-", "%m", "-%d", " %H:", "%i", ":%s")
154 format_def = ("0000-", "01", "-01", " 00:", "00", ":00")
155 if lookup_type == "quarter":
156 return (
157 f"CAST(DATE_FORMAT(MAKEDATE(YEAR({sql}), 1) + "
158 f"INTERVAL QUARTER({sql}) QUARTER - "
159 f"INTERVAL 1 QUARTER, %s) AS DATETIME)"
160 ), (*params, *params, "%Y-%m-01 00:00:00")
161 if lookup_type == "week":
162 return (
163 f"CAST(DATE_FORMAT("
164 f"DATE_SUB({sql}, INTERVAL WEEKDAY({sql}) DAY), %s) AS DATETIME)"
165 ), (*params, *params, "%Y-%m-%d 00:00:00")
166 try:
167 i = fields.index(lookup_type) + 1
168 except ValueError:
169 pass
170 else:
171 format_str = "".join(format[:i] + format_def[i:])
172 return f"CAST(DATE_FORMAT({sql}, %s) AS DATETIME)", (*params, format_str)
173 return sql, params
174
175 def time_trunc_sql(
176 self,
177 lookup_type: str,
178 sql: str,
179 params: list[Any] | tuple[Any, ...],
180 tzname: str | None = None,
181 ) -> tuple[str, list[Any] | tuple[Any, ...]]:
182 sql, params = self._convert_sql_to_tz(sql, params, tzname)
183 fields = {
184 "hour": "%H:00:00",
185 "minute": "%H:%i:00",
186 "second": "%H:%i:%s",
187 }
188 if lookup_type in fields:
189 format_str = fields[lookup_type]
190 return f"CAST(DATE_FORMAT({sql}, %s) AS TIME)", (*params, format_str)
191 else:
192 return f"TIME({sql})", params
193
194 def fetch_returned_insert_rows(self, cursor: CursorWrapper) -> list[Any]:
195 """
196 Given a cursor object that has just performed an INSERT...RETURNING
197 statement into a table, return the tuple of returned data.
198 """
199 return cursor.fetchall()
200
201 def format_for_duration_arithmetic(self, sql: str) -> str:
202 return f"INTERVAL {sql} MICROSECOND"
203
204 def force_no_ordering(self) -> list[tuple[None, tuple[str, list[Any], bool]]]:
205 """
206 "ORDER BY NULL" prevents MySQL from implicitly ordering by grouped
207 columns. If no ordering would otherwise be applied, we don't want any
208 implicit sorting going on.
209 """
210 return [(None, ("NULL", [], False))]
211
212 def adapt_decimalfield_value(
213 self,
214 value: Any,
215 max_digits: int | None = None,
216 decimal_places: int | None = None,
217 ) -> Any:
218 return value
219
220 def last_executed_query(
221 self, cursor: CursorWrapper, sql: str, params: Any
222 ) -> str | None:
223 # With MySQLdb, cursor objects have an (undocumented) "_executed"
224 # attribute where the exact query sent to the database is saved.
225 # See MySQLdb/cursors.py in the source distribution.
226 # MySQLdb returns string, PyMySQL bytes.
227 return force_str(getattr(cursor, "_executed", None), errors="replace")
228
229 def no_limit_value(self) -> int:
230 # 2**64 - 1, as recommended by the MySQL documentation
231 return 18446744073709551615
232
233 def quote_name(self, name: str) -> str:
234 if name.startswith("`") and name.endswith("`"):
235 return name # Quoting once is enough.
236 return f"`{name}`"
237
238 def return_insert_columns(self, fields: list[Any]) -> tuple[str, tuple[Any, ...]]:
239 # MySQL and MariaDB < 10.5.0 don't support an INSERT...RETURNING
240 # statement.
241 if not fields:
242 return "", ()
243 columns = [
244 f"{self.quote_name(field.model.model_options.db_table)}.{self.quote_name(field.column)}"
245 for field in fields
246 ]
247 return "RETURNING {}".format(", ".join(columns)), ()
248
249 def validate_autopk_value(self, value: int) -> int:
250 # Zero in AUTO_INCREMENT field does not work without the
251 # NO_AUTO_VALUE_ON_ZERO SQL mode.
252 if value == 0 and not self.connection.features.allows_auto_pk_0:
253 raise ValueError(
254 "The database backend does not accept 0 as a value for PrimaryKeyField."
255 )
256 return value
257
258 def adapt_datetimefield_value(
259 self, value: datetime.datetime | Any | None
260 ) -> str | Any | None:
261 if value is None:
262 return None
263
264 # Expression values are adapted by the database.
265 if isinstance(value, ResolvableExpression):
266 return value
267
268 # MySQL doesn't support tz-aware datetimes
269 if timezone.is_aware(value):
270 value = timezone.make_naive(value, self.connection.timezone)
271 return str(value)
272
273 def adapt_timefield_value(
274 self, value: datetime.time | Any | None
275 ) -> str | Any | None:
276 if value is None:
277 return None
278
279 # Expression values are adapted by the database.
280 if isinstance(value, ResolvableExpression):
281 return value
282
283 # MySQL doesn't support tz-aware times
284 if timezone.is_aware(value):
285 raise ValueError("MySQL backend does not support timezone-aware times.")
286
287 return value.isoformat(timespec="microseconds")
288
289 def max_name_length(self) -> int:
290 return 64
291
292 def pk_default_value(self) -> str:
293 return "NULL"
294
295 def bulk_insert_sql(
296 self, fields: list[Any], placeholder_rows: list[list[str]]
297 ) -> str:
298 placeholder_rows_sql = (", ".join(row) for row in placeholder_rows)
299 values_sql = ", ".join(f"({sql})" for sql in placeholder_rows_sql)
300 return "VALUES " + values_sql
301
302 def combine_expression(self, connector: str, sub_expressions: list[str]) -> str:
303 if connector == "^":
304 return "POW({})".format(",".join(sub_expressions))
305 # Convert the result to a signed integer since MySQL's binary operators
306 # return an unsigned integer.
307 elif connector in ("&", "|", "<<", "#"):
308 connector = "^" if connector == "#" else connector
309 return f"CONVERT({connector.join(sub_expressions)}, SIGNED)"
310 elif connector == ">>":
311 lhs, rhs = sub_expressions
312 return f"FLOOR({lhs} / POW(2, {rhs}))"
313 return super().combine_expression(connector, sub_expressions)
314
315 def get_db_converters(self, expression: Any) -> list[Any]:
316 converters = super().get_db_converters(expression)
317 internal_type = expression.output_field.get_internal_type()
318 if internal_type == "BooleanField":
319 converters.append(self.convert_booleanfield_value)
320 elif internal_type == "DateTimeField":
321 converters.append(self.convert_datetimefield_value)
322 elif internal_type == "UUIDField":
323 converters.append(self.convert_uuidfield_value)
324 return converters
325
326 def convert_booleanfield_value(
327 self, value: Any, expression: Any, connection: BaseDatabaseWrapper
328 ) -> Any:
329 if value in (0, 1):
330 value = bool(value)
331 return value
332
333 def convert_datetimefield_value(
334 self, value: Any, expression: Any, connection: BaseDatabaseWrapper
335 ) -> datetime.datetime | None:
336 if value is not None:
337 value = timezone.make_aware(value, self.connection.timezone)
338 return value
339
340 def convert_uuidfield_value(
341 self, value: Any, expression: Any, connection: BaseDatabaseWrapper
342 ) -> uuid.UUID | None:
343 if value is not None:
344 value = uuid.UUID(value)
345 return value
346
347 def binary_placeholder_sql(self, value: Any) -> str:
348 return (
349 "_binary %s" if value is not None and not hasattr(value, "as_sql") else "%s"
350 )
351
352 def subtract_temporals(
353 self,
354 internal_type: str,
355 lhs: tuple[str, list[Any] | tuple[Any, ...]],
356 rhs: tuple[str, list[Any] | tuple[Any, ...]],
357 ) -> tuple[str, tuple[Any, ...]]:
358 lhs_sql, lhs_params = lhs
359 rhs_sql, rhs_params = rhs
360 if internal_type == "TimeField":
361 if self.connection.mysql_is_mariadb:
362 # MariaDB includes the microsecond component in TIME_TO_SEC as
363 # a decimal. MySQL returns an integer without microseconds.
364 return (
365 f"CAST((TIME_TO_SEC({lhs_sql}) - TIME_TO_SEC({rhs_sql})) "
366 "* 1000000 AS SIGNED)"
367 ), (
368 *lhs_params,
369 *rhs_params,
370 )
371 return (
372 f"((TIME_TO_SEC({lhs_sql}) * 1000000 + MICROSECOND({lhs_sql})) -"
373 f" (TIME_TO_SEC({rhs_sql}) * 1000000 + MICROSECOND({rhs_sql})))"
374 ), tuple(lhs_params) * 2 + tuple(rhs_params) * 2
375 params = (*rhs_params, *lhs_params)
376 return f"TIMESTAMPDIFF(MICROSECOND, {rhs_sql}, {lhs_sql})", params
377
    def explain_query_prefix(self, format: str | None = None, **options: Any) -> str:
        """
        Return the EXPLAIN prefix for this backend.

        Format handling:

        - A ``format`` of ``"TEXT"`` (any case) is translated to MySQL's
          ``TRADITIONAL``.
        - When no format is given and ``"TREE"`` appears in
          ``features.supported_explain_formats``, ``TREE`` is used.

        Analyze handling:

        - The ``analyze`` option is popped from ``options`` before the rest
          go to the base implementation.
        - It is only applied when ``features.supports_explain_analyze`` is
          true. MariaDB then gets a bare ``ANALYZE`` prefix; MySQL gets
          `` ANALYZE`` appended.
        - A ``FORMAT=<format>`` suffix is added when a format is set, except
          when ``analyze`` is requested on MySQL.

        NOTE(review): on MySQL a truthy ``analyze`` suppresses ``FORMAT=``
        even when ``supports_explain_analyze`` is false, so neither ANALYZE
        nor the format ends up in the prefix. Confirm this is intended.
        """
        # Alias MySQL's TRADITIONAL to TEXT for consistency with other backends.
        if format and format.upper() == "TEXT":
            format = "TRADITIONAL"
        elif (
            not format and "TREE" in self.connection.features.supported_explain_formats
        ):
            # Use TREE by default (if supported) as it's more informative.
            format = "TREE"
        # Remove "analyze" so the base implementation never receives it.
        analyze = options.pop("analyze", False)
        prefix = super().explain_query_prefix(format, **options)
        if analyze and self.connection.features.supports_explain_analyze:
            # MariaDB uses ANALYZE instead of EXPLAIN ANALYZE.
            prefix = (
                "ANALYZE" if self.connection.mysql_is_mariadb else prefix + " ANALYZE"
            )
        if format and not (analyze and not self.connection.mysql_is_mariadb):
            # Only MariaDB supports the analyze option with formats.
            prefix += f" FORMAT={format}"
        return prefix
398
399 def regex_lookup(self, lookup_type: str) -> str:
400 # REGEXP_LIKE doesn't exist in MariaDB.
401 if self.connection.mysql_is_mariadb:
402 if lookup_type == "regex":
403 return "%s REGEXP BINARY %s"
404 return "%s REGEXP %s"
405
406 match_option = "c" if lookup_type == "regex" else "i"
407 return f"REGEXP_LIKE(%s, %s, '{match_option}')"
408
409 def insert_statement(self, on_conflict: Any = None) -> str:
410 if on_conflict == OnConflict.IGNORE:
411 return "INSERT IGNORE INTO"
412 return super().insert_statement(on_conflict=on_conflict)
413
414 def lookup_cast(self, lookup_type: str, internal_type: str | None = None) -> str:
415 lookup = "%s"
416 if internal_type == "JSONField":
417 if self.connection.mysql_is_mariadb or lookup_type in (
418 "iexact",
419 "contains",
420 "icontains",
421 "startswith",
422 "istartswith",
423 "endswith",
424 "iendswith",
425 "regex",
426 "iregex",
427 ):
428 lookup = "JSON_UNQUOTE(%s)"
429 return lookup
430
431 def conditional_expression_supported_in_where_clause(self, expression: Any) -> bool:
432 # MySQL ignores indexes with boolean fields unless they're compared
433 # directly to a boolean value.
434 if isinstance(expression, Exists | Lookup):
435 return True
436 if isinstance(expression, ExpressionWrapper) and expression.conditional:
437 return self.conditional_expression_supported_in_where_clause(
438 expression.expression
439 )
440 if getattr(expression, "conditional", False):
441 return False
442 return super().conditional_expression_supported_in_where_clause(expression)
443
444 def on_conflict_suffix_sql(
445 self,
446 fields: list[Field],
447 on_conflict: Any,
448 update_fields: Iterable[str],
449 unique_fields: Iterable[str],
450 ) -> str:
451 if on_conflict == OnConflict.UPDATE:
452 conflict_suffix_sql = "ON DUPLICATE KEY UPDATE %(fields)s"
453 # The use of VALUES() is deprecated in MySQL 8.0.20+. Instead, use
454 # aliases for the new row and its columns available in MySQL
455 # 8.0.19+.
456 if not self.connection.mysql_is_mariadb:
457 if self.connection.mysql_version >= (8, 0, 19):
458 conflict_suffix_sql = f"AS new {conflict_suffix_sql}"
459 field_sql = "%(field)s = new.%(field)s"
460 else:
461 field_sql = "%(field)s = VALUES(%(field)s)"
462 # Use VALUE() on MariaDB.
463 else:
464 field_sql = "%(field)s = VALUE(%(field)s)"
465
466 fields_str = ", ".join(
467 [
468 field_sql % {"field": field}
469 for field in map(self.quote_name, update_fields)
470 ]
471 )
472 return conflict_suffix_sql % {"fields": fields_str}
473 return super().on_conflict_suffix_sql(
474 fields,
475 on_conflict,
476 update_fields,
477 unique_fields,
478 )