1from __future__ import annotations
2
3import datetime
4import uuid
5from collections.abc import Iterable
6from typing import TYPE_CHECKING, Any
7
8from plain.models.backends.base.operations import BaseDatabaseOperations
9from plain.models.backends.utils import split_tzname_delta
10from plain.models.constants import OnConflict
11from plain.models.expressions import Exists, ExpressionWrapper
12from plain.models.lookups import Lookup
13from plain.utils import timezone
14from plain.utils.encoding import force_str
15from plain.utils.regex_helper import _lazy_re_compile
16
17if TYPE_CHECKING:
18 from plain.models.backends.base.base import BaseDatabaseWrapper
19 from plain.models.backends.mysql.base import MySQLDatabaseWrapper
20 from plain.models.fields import Field
21
22
23class DatabaseOperations(BaseDatabaseOperations):
24 # Type checker hint: connection is always MySQLDatabaseWrapper in this class
25 connection: MySQLDatabaseWrapper
26
27 compiler_module = "plain.models.backends.mysql.compiler"
28
29 # MySQL stores positive fields as UNSIGNED ints.
30 integer_field_ranges = {
31 **BaseDatabaseOperations.integer_field_ranges,
32 "PositiveSmallIntegerField": (0, 65535),
33 "PositiveIntegerField": (0, 4294967295),
34 "PositiveBigIntegerField": (0, 18446744073709551615),
35 }
36 cast_data_types = {
37 "PrimaryKeyField": "signed integer",
38 "CharField": "char(%(max_length)s)",
39 "DecimalField": "decimal(%(max_digits)s, %(decimal_places)s)",
40 "TextField": "char",
41 "IntegerField": "signed integer",
42 "BigIntegerField": "signed integer",
43 "SmallIntegerField": "signed integer",
44 "PositiveBigIntegerField": "unsigned integer",
45 "PositiveIntegerField": "unsigned integer",
46 "PositiveSmallIntegerField": "unsigned integer",
47 "DurationField": "signed integer",
48 }
49 cast_char_field_without_max_length = "char"
50 explain_prefix = "EXPLAIN"
51
52 # EXTRACT format cannot be passed in parameters.
53 _extract_format_re = _lazy_re_compile(r"[A-Z_]+")
54
55 def date_extract_sql(
56 self, lookup_type: str, sql: str, params: list[Any] | tuple[Any, ...]
57 ) -> tuple[str, list[Any] | tuple[Any, ...]]:
58 # https://dev.mysql.com/doc/mysql/en/date-and-time-functions.html
59 if lookup_type == "week_day":
60 # DAYOFWEEK() returns an integer, 1-7, Sunday=1.
61 return f"DAYOFWEEK({sql})", params
62 elif lookup_type == "iso_week_day":
63 # WEEKDAY() returns an integer, 0-6, Monday=0.
64 return f"WEEKDAY({sql}) + 1", params
65 elif lookup_type == "week":
66 # Override the value of default_week_format for consistency with
67 # other database backends.
68 # Mode 3: Monday, 1-53, with 4 or more days this year.
69 return f"WEEK({sql}, 3)", params
70 elif lookup_type == "iso_year":
71 # Get the year part from the YEARWEEK function, which returns a
72 # number as year * 100 + week.
73 return f"TRUNCATE(YEARWEEK({sql}, 3), -2) / 100", params
74 else:
75 # EXTRACT returns 1-53 based on ISO-8601 for the week number.
76 lookup_type = lookup_type.upper()
77 if not self._extract_format_re.fullmatch(lookup_type):
78 raise ValueError(f"Invalid loookup type: {lookup_type!r}")
79 return f"EXTRACT({lookup_type} FROM {sql})", params
80
81 def date_trunc_sql(
82 self,
83 lookup_type: str,
84 sql: str,
85 params: list[Any] | tuple[Any, ...],
86 tzname: str | None = None,
87 ) -> tuple[str, list[Any] | tuple[Any, ...]]:
88 sql, params = self._convert_sql_to_tz(sql, params, tzname)
89 fields = {
90 "year": "%Y-01-01",
91 "month": "%Y-%m-01",
92 }
93 if lookup_type in fields:
94 format_str = fields[lookup_type]
95 return f"CAST(DATE_FORMAT({sql}, %s) AS DATE)", (*params, format_str)
96 elif lookup_type == "quarter":
97 return (
98 f"MAKEDATE(YEAR({sql}), 1) + "
99 f"INTERVAL QUARTER({sql}) QUARTER - INTERVAL 1 QUARTER",
100 (*params, *params),
101 )
102 elif lookup_type == "week":
103 return f"DATE_SUB({sql}, INTERVAL WEEKDAY({sql}) DAY)", (*params, *params)
104 else:
105 return f"DATE({sql})", params
106
107 def _prepare_tzname_delta(self, tzname: str) -> str:
108 tzname, sign, offset = split_tzname_delta(tzname)
109 return f"{sign}{offset}" if offset else tzname
110
111 def _convert_sql_to_tz(
112 self, sql: str, params: list[Any] | tuple[Any, ...], tzname: str | None
113 ) -> tuple[str, list[Any] | tuple[Any, ...]]:
114 if tzname and self.connection.timezone_name != tzname:
115 return f"CONVERT_TZ({sql}, %s, %s)", (
116 *params,
117 self.connection.timezone_name,
118 self._prepare_tzname_delta(tzname),
119 )
120 return sql, params
121
122 def datetime_cast_date_sql(
123 self, sql: str, params: list[Any] | tuple[Any, ...], tzname: str | None
124 ) -> tuple[str, list[Any] | tuple[Any, ...]]:
125 sql, params = self._convert_sql_to_tz(sql, params, tzname)
126 return f"DATE({sql})", params
127
128 def datetime_cast_time_sql(
129 self, sql: str, params: list[Any] | tuple[Any, ...], tzname: str | None
130 ) -> tuple[str, list[Any] | tuple[Any, ...]]:
131 sql, params = self._convert_sql_to_tz(sql, params, tzname)
132 return f"TIME({sql})", params
133
134 def datetime_extract_sql(
135 self,
136 lookup_type: str,
137 sql: str,
138 params: list[Any] | tuple[Any, ...],
139 tzname: str | None,
140 ) -> tuple[str, list[Any] | tuple[Any, ...]]:
141 sql, params = self._convert_sql_to_tz(sql, params, tzname)
142 return self.date_extract_sql(lookup_type, sql, params)
143
144 def datetime_trunc_sql(
145 self,
146 lookup_type: str,
147 sql: str,
148 params: list[Any] | tuple[Any, ...],
149 tzname: str | None,
150 ) -> tuple[str, list[Any] | tuple[Any, ...]]:
151 sql, params = self._convert_sql_to_tz(sql, params, tzname)
152 fields = ["year", "month", "day", "hour", "minute", "second"]
153 format = ("%Y-", "%m", "-%d", " %H:", "%i", ":%s")
154 format_def = ("0000-", "01", "-01", " 00:", "00", ":00")
155 if lookup_type == "quarter":
156 return (
157 f"CAST(DATE_FORMAT(MAKEDATE(YEAR({sql}), 1) + "
158 f"INTERVAL QUARTER({sql}) QUARTER - "
159 f"INTERVAL 1 QUARTER, %s) AS DATETIME)"
160 ), (*params, *params, "%Y-%m-01 00:00:00")
161 if lookup_type == "week":
162 return (
163 f"CAST(DATE_FORMAT("
164 f"DATE_SUB({sql}, INTERVAL WEEKDAY({sql}) DAY), %s) AS DATETIME)"
165 ), (*params, *params, "%Y-%m-%d 00:00:00")
166 try:
167 i = fields.index(lookup_type) + 1
168 except ValueError:
169 pass
170 else:
171 format_str = "".join(format[:i] + format_def[i:])
172 return f"CAST(DATE_FORMAT({sql}, %s) AS DATETIME)", (*params, format_str)
173 return sql, params
174
175 def time_trunc_sql(
176 self,
177 lookup_type: str,
178 sql: str,
179 params: list[Any] | tuple[Any, ...],
180 tzname: str | None = None,
181 ) -> tuple[str, list[Any] | tuple[Any, ...]]:
182 sql, params = self._convert_sql_to_tz(sql, params, tzname)
183 fields = {
184 "hour": "%H:00:00",
185 "minute": "%H:%i:00",
186 "second": "%H:%i:%s",
187 }
188 if lookup_type in fields:
189 format_str = fields[lookup_type]
190 return f"CAST(DATE_FORMAT({sql}, %s) AS TIME)", (*params, format_str)
191 else:
192 return f"TIME({sql})", params
193
194 def fetch_returned_insert_rows(self, cursor: Any) -> list[Any]:
195 """
196 Given a cursor object that has just performed an INSERT...RETURNING
197 statement into a table, return the tuple of returned data.
198 """
199 return cursor.fetchall()
200
201 def format_for_duration_arithmetic(self, sql: str) -> str:
202 return f"INTERVAL {sql} MICROSECOND"
203
204 def force_no_ordering(self) -> list[tuple[None, tuple[str, list[Any], bool]]]:
205 """
206 "ORDER BY NULL" prevents MySQL from implicitly ordering by grouped
207 columns. If no ordering would otherwise be applied, we don't want any
208 implicit sorting going on.
209 """
210 return [(None, ("NULL", [], False))]
211
212 def adapt_decimalfield_value(
213 self,
214 value: Any,
215 max_digits: int | None = None,
216 decimal_places: int | None = None,
217 ) -> Any:
218 return value
219
220 def last_executed_query(self, cursor: Any, sql: str, params: Any) -> str | None:
221 # With MySQLdb, cursor objects have an (undocumented) "_executed"
222 # attribute where the exact query sent to the database is saved.
223 # See MySQLdb/cursors.py in the source distribution.
224 # MySQLdb returns string, PyMySQL bytes.
225 return force_str(getattr(cursor, "_executed", None), errors="replace")
226
227 def no_limit_value(self) -> int:
228 # 2**64 - 1, as recommended by the MySQL documentation
229 return 18446744073709551615
230
231 def quote_name(self, name: str) -> str:
232 if name.startswith("`") and name.endswith("`"):
233 return name # Quoting once is enough.
234 return f"`{name}`"
235
236 def return_insert_columns(self, fields: list[Any]) -> tuple[str, tuple[Any, ...]]:
237 # MySQL and MariaDB < 10.5.0 don't support an INSERT...RETURNING
238 # statement.
239 if not fields:
240 return "", ()
241 columns = [
242 f"{self.quote_name(field.model.model_options.db_table)}.{self.quote_name(field.column)}"
243 for field in fields
244 ]
245 return "RETURNING {}".format(", ".join(columns)), ()
246
247 def validate_autopk_value(self, value: int) -> int:
248 # Zero in AUTO_INCREMENT field does not work without the
249 # NO_AUTO_VALUE_ON_ZERO SQL mode.
250 if value == 0 and not self.connection.features.allows_auto_pk_0:
251 raise ValueError(
252 "The database backend does not accept 0 as a value for PrimaryKeyField."
253 )
254 return value
255
256 def adapt_datetimefield_value(
257 self, value: datetime.datetime | Any | None
258 ) -> str | Any | None:
259 if value is None:
260 return None
261
262 # Expression values are adapted by the database.
263 if hasattr(value, "resolve_expression"):
264 return value
265
266 # MySQL doesn't support tz-aware datetimes
267 if timezone.is_aware(value):
268 value = timezone.make_naive(value, self.connection.timezone)
269 return str(value)
270
271 def adapt_timefield_value(
272 self, value: datetime.time | Any | None
273 ) -> str | Any | None:
274 if value is None:
275 return None
276
277 # Expression values are adapted by the database.
278 if hasattr(value, "resolve_expression"):
279 return value
280
281 # MySQL doesn't support tz-aware times
282 if timezone.is_aware(value): # type: ignore[arg-type]
283 raise ValueError("MySQL backend does not support timezone-aware times.")
284
285 return value.isoformat(timespec="microseconds")
286
287 def max_name_length(self) -> int:
288 return 64
289
290 def pk_default_value(self) -> str:
291 return "NULL"
292
293 def bulk_insert_sql(
294 self, fields: list[Any], placeholder_rows: list[list[str]]
295 ) -> str:
296 placeholder_rows_sql = (", ".join(row) for row in placeholder_rows)
297 values_sql = ", ".join(f"({sql})" for sql in placeholder_rows_sql)
298 return "VALUES " + values_sql
299
300 def combine_expression(self, connector: str, sub_expressions: list[str]) -> str:
301 if connector == "^":
302 return "POW({})".format(",".join(sub_expressions))
303 # Convert the result to a signed integer since MySQL's binary operators
304 # return an unsigned integer.
305 elif connector in ("&", "|", "<<", "#"):
306 connector = "^" if connector == "#" else connector
307 return f"CONVERT({connector.join(sub_expressions)}, SIGNED)"
308 elif connector == ">>":
309 lhs, rhs = sub_expressions
310 return f"FLOOR({lhs} / POW(2, {rhs}))"
311 return super().combine_expression(connector, sub_expressions)
312
313 def get_db_converters(self, expression: Any) -> list[Any]:
314 converters = super().get_db_converters(expression)
315 internal_type = expression.output_field.get_internal_type()
316 if internal_type == "BooleanField":
317 converters.append(self.convert_booleanfield_value)
318 elif internal_type == "DateTimeField":
319 converters.append(self.convert_datetimefield_value)
320 elif internal_type == "UUIDField":
321 converters.append(self.convert_uuidfield_value)
322 return converters
323
324 def convert_booleanfield_value(
325 self, value: Any, expression: Any, connection: BaseDatabaseWrapper
326 ) -> Any:
327 if value in (0, 1):
328 value = bool(value)
329 return value
330
331 def convert_datetimefield_value(
332 self, value: Any, expression: Any, connection: BaseDatabaseWrapper
333 ) -> datetime.datetime | None:
334 if value is not None:
335 value = timezone.make_aware(value, self.connection.timezone)
336 return value
337
338 def convert_uuidfield_value(
339 self, value: Any, expression: Any, connection: BaseDatabaseWrapper
340 ) -> uuid.UUID | None:
341 if value is not None:
342 value = uuid.UUID(value)
343 return value
344
345 def binary_placeholder_sql(self, value: Any) -> str:
346 return (
347 "_binary %s" if value is not None and not hasattr(value, "as_sql") else "%s"
348 )
349
350 def subtract_temporals(
351 self,
352 internal_type: str,
353 lhs: tuple[str, list[Any] | tuple[Any, ...]],
354 rhs: tuple[str, list[Any] | tuple[Any, ...]],
355 ) -> tuple[str, tuple[Any, ...]]:
356 lhs_sql, lhs_params = lhs
357 rhs_sql, rhs_params = rhs
358 if internal_type == "TimeField":
359 if self.connection.mysql_is_mariadb:
360 # MariaDB includes the microsecond component in TIME_TO_SEC as
361 # a decimal. MySQL returns an integer without microseconds.
362 return (
363 f"CAST((TIME_TO_SEC({lhs_sql}) - TIME_TO_SEC({rhs_sql})) "
364 "* 1000000 AS SIGNED)"
365 ), (
366 *lhs_params,
367 *rhs_params,
368 )
369 return (
370 f"((TIME_TO_SEC({lhs_sql}) * 1000000 + MICROSECOND({lhs_sql})) -"
371 f" (TIME_TO_SEC({rhs_sql}) * 1000000 + MICROSECOND({rhs_sql})))"
372 ), tuple(lhs_params) * 2 + tuple(rhs_params) * 2
373 params = (*rhs_params, *lhs_params)
374 return f"TIMESTAMPDIFF(MICROSECOND, {rhs_sql}, {lhs_sql})", params
375
    def explain_query_prefix(self, format: str | None = None, **options: Any) -> str:
        """
        Return the EXPLAIN prefix for a query, adjusted for MySQL/MariaDB.

        - ``format="TEXT"`` (any case) is translated to MySQL's TRADITIONAL.
        - With no ``format``, TREE is used when the connection's features list
          it in ``supported_explain_formats``.
        - An ``analyze`` option is popped from ``options``. If the backend
          reports EXPLAIN ANALYZE support, the prefix becomes ``ANALYZE`` on
          MariaDB or ``<prefix> ANALYZE`` on MySQL.
        - ``FORMAT=<format>`` is appended when a format is set, except when
          ``analyze`` was requested on MySQL (i.e. not MariaDB).

        The remaining ``options`` are forwarded to the base implementation.
        """
        # Alias MySQL's TRADITIONAL to TEXT for consistency with other backends.
        if format and format.upper() == "TEXT":
            format = "TRADITIONAL"
        elif (
            not format and "TREE" in self.connection.features.supported_explain_formats
        ):
            # Use TREE by default (if supported) as it's more informative.
            format = "TREE"
        # Remove "analyze" so the base class only receives the other options.
        analyze = options.pop("analyze", False)
        prefix = super().explain_query_prefix(format, **options)
        if analyze and self.connection.features.supports_explain_analyze:
            # MariaDB uses ANALYZE instead of EXPLAIN ANALYZE.
            prefix = (
                "ANALYZE" if self.connection.mysql_is_mariadb else prefix + " ANALYZE"
            )
        # NOTE(review): on MySQL the FORMAT clause is also dropped when analyze
        # was requested but supports_explain_analyze is False — confirm that
        # silently ignoring the requested format is intended.
        if format and not (analyze and not self.connection.mysql_is_mariadb):
            # Only MariaDB supports the analyze option with formats.
            prefix += f" FORMAT={format}"
        return prefix
396
397 def regex_lookup(self, lookup_type: str) -> str:
398 # REGEXP_LIKE doesn't exist in MariaDB.
399 if self.connection.mysql_is_mariadb:
400 if lookup_type == "regex":
401 return "%s REGEXP BINARY %s"
402 return "%s REGEXP %s"
403
404 match_option = "c" if lookup_type == "regex" else "i"
405 return f"REGEXP_LIKE(%s, %s, '{match_option}')"
406
407 def insert_statement(self, on_conflict: Any = None) -> str:
408 if on_conflict == OnConflict.IGNORE:
409 return "INSERT IGNORE INTO"
410 return super().insert_statement(on_conflict=on_conflict)
411
412 def lookup_cast(self, lookup_type: str, internal_type: str | None = None) -> str:
413 lookup = "%s"
414 if internal_type == "JSONField":
415 if self.connection.mysql_is_mariadb or lookup_type in (
416 "iexact",
417 "contains",
418 "icontains",
419 "startswith",
420 "istartswith",
421 "endswith",
422 "iendswith",
423 "regex",
424 "iregex",
425 ):
426 lookup = "JSON_UNQUOTE(%s)"
427 return lookup
428
429 def conditional_expression_supported_in_where_clause(self, expression: Any) -> bool:
430 # MySQL ignores indexes with boolean fields unless they're compared
431 # directly to a boolean value.
432 if isinstance(expression, Exists | Lookup):
433 return True
434 if isinstance(expression, ExpressionWrapper) and expression.conditional:
435 return self.conditional_expression_supported_in_where_clause(
436 expression.expression
437 )
438 if getattr(expression, "conditional", False):
439 return False
440 return super().conditional_expression_supported_in_where_clause(expression)
441
442 def on_conflict_suffix_sql(
443 self,
444 fields: list[Field],
445 on_conflict: Any,
446 update_fields: Iterable[str],
447 unique_fields: Iterable[str],
448 ) -> str:
449 if on_conflict == OnConflict.UPDATE:
450 conflict_suffix_sql = "ON DUPLICATE KEY UPDATE %(fields)s"
451 # The use of VALUES() is deprecated in MySQL 8.0.20+. Instead, use
452 # aliases for the new row and its columns available in MySQL
453 # 8.0.19+.
454 if not self.connection.mysql_is_mariadb:
455 if self.connection.mysql_version >= (8, 0, 19):
456 conflict_suffix_sql = f"AS new {conflict_suffix_sql}"
457 field_sql = "%(field)s = new.%(field)s"
458 else:
459 field_sql = "%(field)s = VALUES(%(field)s)"
460 # Use VALUE() on MariaDB.
461 else:
462 field_sql = "%(field)s = VALUE(%(field)s)"
463
464 fields_str = ", ".join(
465 [
466 field_sql % {"field": field}
467 for field in map(self.quote_name, update_fields)
468 ]
469 )
470 return conflict_suffix_sql % {"fields": fields_str}
471 return super().on_conflict_suffix_sql(
472 fields,
473 on_conflict,
474 update_fields,
475 unique_fields,
476 )