1from __future__ import annotations
2
3import datetime
4import uuid
5from typing import TYPE_CHECKING, Any
6
7from plain.models.backends.base.operations import BaseDatabaseOperations
8from plain.models.backends.utils import split_tzname_delta
9from plain.models.constants import OnConflict
10from plain.models.expressions import Exists, ExpressionWrapper
11from plain.models.lookups import Lookup
12from plain.utils import timezone
13from plain.utils.encoding import force_str
14from plain.utils.regex_helper import _lazy_re_compile
15
16if TYPE_CHECKING:
17 from plain.models.backends.base.base import BaseDatabaseWrapper
18
19
class DatabaseOperations(BaseDatabaseOperations):
    """
    MySQL / MariaDB implementation of the backend SQL-generation hooks.

    Each method overrides a hook from ``BaseDatabaseOperations`` to emit the
    MySQL dialect (backtick quoting, DATE_FORMAT-based truncation,
    ``ON DUPLICATE KEY UPDATE`` upserts, ...) or to adapt values to and from
    the database driver. Where MySQL and MariaDB diverge, methods branch on
    ``self.connection.mysql_is_mariadb``.
    """

    compiler_module = "plain.models.backends.mysql.compiler"

    # MySQL stores positive fields as UNSIGNED ints, so their upper bounds are
    # the full unsigned range of the corresponding column type.
    integer_field_ranges = {
        **BaseDatabaseOperations.integer_field_ranges,
        "PositiveSmallIntegerField": (0, 65535),
        "PositiveIntegerField": (0, 4294967295),
        "PositiveBigIntegerField": (0, 18446744073709551615),
    }
    # MySQL type names used when casting an expression to a given field type.
    # The %(...)s placeholders (max_length, max_digits, decimal_places) are
    # field attributes substituted by the caller.
    cast_data_types = {
        "PrimaryKeyField": "signed integer",
        "CharField": "char(%(max_length)s)",
        "DecimalField": "decimal(%(max_digits)s, %(decimal_places)s)",
        "TextField": "char",
        "IntegerField": "signed integer",
        "BigIntegerField": "signed integer",
        "SmallIntegerField": "signed integer",
        "PositiveBigIntegerField": "unsigned integer",
        "PositiveIntegerField": "unsigned integer",
        "PositiveSmallIntegerField": "unsigned integer",
        "DurationField": "signed integer",
    }
    cast_char_field_without_max_length = "char"
    explain_prefix = "EXPLAIN"

    # The EXTRACT unit cannot be passed as a query parameter, so it is spliced
    # into the SQL text. This pattern restricts it to upper-case letters and
    # underscores so a lookup name cannot be used to inject SQL.
    _extract_format_re = _lazy_re_compile(r"[A-Z_]+")

    def date_extract_sql(
        self, lookup_type: str, sql: str, params: list[Any] | tuple[Any, ...]
    ) -> tuple[str, list[Any] | tuple[Any, ...]]:
        """
        Return ``(sql, params)`` extracting the ``lookup_type`` component of
        the date/datetime expression ``sql``.

        Raises ValueError when ``lookup_type`` (upper-cased) contains anything
        other than letters and underscores, because it is interpolated into
        the SQL rather than bound as a parameter.
        """
        # https://dev.mysql.com/doc/mysql/en/date-and-time-functions.html
        if lookup_type == "week_day":
            # DAYOFWEEK() returns an integer, 1-7, Sunday=1.
            return f"DAYOFWEEK({sql})", params
        elif lookup_type == "iso_week_day":
            # WEEKDAY() returns an integer, 0-6, Monday=0; shift to 1-7.
            return f"WEEKDAY({sql}) + 1", params
        elif lookup_type == "week":
            # Override the value of default_week_format for consistency with
            # other database backends.
            # Mode 3: Monday, 1-53, with 4 or more days this year.
            return f"WEEK({sql}, 3)", params
        elif lookup_type == "iso_year":
            # Get the year part from the YEARWEEK function, which returns a
            # number as year * 100 + week.
            return f"TRUNCATE(YEARWEEK({sql}, 3), -2) / 100", params
        else:
            # Every other lookup name is used verbatim as the EXTRACT unit
            # (YEAR, MONTH, DAY, HOUR, ...), after validation.
            lookup_type = lookup_type.upper()
            if not self._extract_format_re.fullmatch(lookup_type):
                raise ValueError(f"Invalid lookup type: {lookup_type!r}")
            return f"EXTRACT({lookup_type} FROM {sql})", params

    def date_trunc_sql(
        self,
        lookup_type: str,
        sql: str,
        params: list[Any] | tuple[Any, ...],
        tzname: str | None = None,
    ) -> tuple[str, list[Any] | tuple[Any, ...]]:
        """
        Return ``(sql, params)`` truncating ``sql`` to the start of its
        year / quarter / month / week as a DATE. Any other ``lookup_type``
        falls back to ``DATE(sql)``.

        When ``tzname`` differs from the connection time zone, ``sql`` is
        first converted to that zone.
        """
        sql, params = self._convert_sql_to_tz(sql, params, tzname)
        fields = {
            "year": "%Y-01-01",
            "month": "%Y-%m-01",
        }
        if lookup_type in fields:
            format_str = fields[lookup_type]
            return f"CAST(DATE_FORMAT({sql}, %s) AS DATE)", (*params, format_str)
        elif lookup_type == "quarter":
            # ``sql`` appears twice, so its params must be repeated.
            return (
                f"MAKEDATE(YEAR({sql}), 1) + "
                f"INTERVAL QUARTER({sql}) QUARTER - INTERVAL 1 QUARTER",
                (*params, *params),
            )
        elif lookup_type == "week":
            # Step back WEEKDAY() days (Monday=0) to land on the Monday.
            return f"DATE_SUB({sql}, INTERVAL WEEKDAY({sql}) DAY)", (*params, *params)
        else:
            return f"DATE({sql})", params

    def _prepare_tzname_delta(self, tzname: str) -> str:
        """
        Return the time-zone argument for CONVERT_TZ: the sign and offset when
        ``tzname`` carries an offset suffix, otherwise the zone name itself.
        """
        tzname, sign, offset = split_tzname_delta(tzname)
        return f"{sign}{offset}" if offset else tzname

    def _convert_sql_to_tz(
        self, sql: str, params: list[Any] | tuple[Any, ...], tzname: str | None
    ) -> tuple[str, list[Any] | tuple[Any, ...]]:
        """
        Wrap ``sql`` in CONVERT_TZ from the connection's time zone to
        ``tzname``. Return ``(sql, params)`` unchanged when ``tzname`` is
        empty or already equals the connection's time zone.
        """
        if tzname and self.connection.timezone_name != tzname:
            return f"CONVERT_TZ({sql}, %s, %s)", (
                *params,
                self.connection.timezone_name,
                self._prepare_tzname_delta(tzname),
            )
        return sql, params

    def datetime_cast_date_sql(
        self, sql: str, params: list[Any] | tuple[Any, ...], tzname: str | None
    ) -> tuple[str, list[Any] | tuple[Any, ...]]:
        """Return SQL taking the DATE part of ``sql`` in time zone ``tzname``."""
        sql, params = self._convert_sql_to_tz(sql, params, tzname)
        return f"DATE({sql})", params

    def datetime_cast_time_sql(
        self, sql: str, params: list[Any] | tuple[Any, ...], tzname: str | None
    ) -> tuple[str, list[Any] | tuple[Any, ...]]:
        """Return SQL taking the TIME part of ``sql`` in time zone ``tzname``."""
        sql, params = self._convert_sql_to_tz(sql, params, tzname)
        return f"TIME({sql})", params

    def datetime_extract_sql(
        self,
        lookup_type: str,
        sql: str,
        params: list[Any] | tuple[Any, ...],
        tzname: str | None,
    ) -> tuple[str, list[Any] | tuple[Any, ...]]:
        """
        Like ``date_extract_sql`` but first converts ``sql`` to ``tzname``.
        """
        sql, params = self._convert_sql_to_tz(sql, params, tzname)
        return self.date_extract_sql(lookup_type, sql, params)

    def datetime_trunc_sql(
        self,
        lookup_type: str,
        sql: str,
        params: list[Any] | tuple[Any, ...],
        tzname: str | None,
    ) -> tuple[str, list[Any] | tuple[Any, ...]]:
        """
        Return ``(sql, params)`` truncating the datetime ``sql`` (converted to
        ``tzname``) to the given unit, cast back to DATETIME.

        Supports quarter, week, and year/month/day/hour/minute/second. Any
        other ``lookup_type`` returns the (time-zone converted) expression
        unchanged.
        """
        sql, params = self._convert_sql_to_tz(sql, params, tzname)
        fields = ["year", "month", "day", "hour", "minute", "second"]
        # DATE_FORMAT piece that keeps each unit in ``fields`` ...
        format_parts = ("%Y-", "%m", "-%d", " %H:", "%i", ":%s")
        # ... and the constant that replaces it once truncated away.
        default_parts = ("0000-", "01", "-01", " 00:", "00", ":00")
        if lookup_type == "quarter":
            return (
                f"CAST(DATE_FORMAT(MAKEDATE(YEAR({sql}), 1) + "
                f"INTERVAL QUARTER({sql}) QUARTER - "
                f"INTERVAL 1 QUARTER, %s) AS DATETIME)"
            ), (*params, *params, "%Y-%m-01 00:00:00")
        if lookup_type == "week":
            return (
                f"CAST(DATE_FORMAT("
                f"DATE_SUB({sql}, INTERVAL WEEKDAY({sql}) DAY), %s) AS DATETIME)"
            ), (*params, *params, "%Y-%m-%d 00:00:00")
        if lookup_type in fields:
            # Keep every unit up to and including ``lookup_type``; zero out
            # (or set to 01) everything finer.
            keep = fields.index(lookup_type) + 1
            format_str = "".join(format_parts[:keep] + default_parts[keep:])
            return f"CAST(DATE_FORMAT({sql}, %s) AS DATETIME)", (*params, format_str)
        return sql, params

    def time_trunc_sql(
        self,
        lookup_type: str,
        sql: str,
        params: list[Any] | tuple[Any, ...],
        tzname: str | None = None,
    ) -> tuple[str, list[Any] | tuple[Any, ...]]:
        """
        Return ``(sql, params)`` truncating the time ``sql`` to the hour,
        minute or second as a TIME. Any other ``lookup_type`` falls back to
        ``TIME(sql)``.
        """
        sql, params = self._convert_sql_to_tz(sql, params, tzname)
        fields = {
            "hour": "%H:00:00",
            "minute": "%H:%i:00",
            "second": "%H:%i:%s",
        }
        if lookup_type in fields:
            format_str = fields[lookup_type]
            return f"CAST(DATE_FORMAT({sql}, %s) AS TIME)", (*params, format_str)
        else:
            return f"TIME({sql})", params

    def fetch_returned_insert_rows(self, cursor: Any) -> list[Any]:
        """
        Given a cursor object that has just performed an INSERT...RETURNING
        statement into a table, return all of the rows it produced (as given
        by ``cursor.fetchall()``).
        """
        return cursor.fetchall()

    def format_for_duration_arithmetic(self, sql: str) -> str:
        """Express the microsecond count ``sql`` as a MySQL INTERVAL."""
        return f"INTERVAL {sql} MICROSECOND"

    def force_no_ordering(self) -> list[tuple[None, tuple[str, list[Any], bool]]]:
        """
        "ORDER BY NULL" prevents MySQL from implicitly ordering by grouped
        columns. If no ordering would otherwise be applied, we don't want any
        implicit sorting going on.
        """
        return [(None, ("NULL", [], False))]

    def adapt_decimalfield_value(
        self,
        value: Any,
        max_digits: int | None = None,
        decimal_places: int | None = None,
    ) -> Any:
        """
        Pass decimal values through unchanged. ``max_digits`` and
        ``decimal_places`` are accepted for interface compatibility but unused.
        """
        return value

    def last_executed_query(self, cursor: Any, sql: str, params: Any) -> str | None:
        """Return the exact query text the driver last sent, as ``str``."""
        # With MySQLdb, cursor objects have an (undocumented) "_executed"
        # attribute where the exact query sent to the database is saved.
        # See MySQLdb/cursors.py in the source distribution.
        # MySQLdb returns string, PyMySQL bytes.
        return force_str(getattr(cursor, "_executed", None), errors="replace")

    def no_limit_value(self) -> int:
        """Return the LIMIT value used to mean "no limit" (OFFSET needs one)."""
        # 2**64 - 1, as recommended by the MySQL documentation
        return 18446744073709551615

    def quote_name(self, name: str) -> str:
        """Wrap ``name`` in backticks unless it is already backtick-quoted."""
        if name.startswith("`") and name.endswith("`"):
            return name  # Quoting once is enough.
        return f"`{name}`"

    def return_insert_columns(self, fields: list[Any]) -> tuple[str, tuple[Any, ...]]:
        """
        Build the ``RETURNING table.col, ...`` clause for ``fields``, or an
        empty clause when there are no fields.
        """
        # Only MariaDB 10.5.0+ supports INSERT...RETURNING; MySQL and older
        # MariaDB do not. NOTE(review): this method does not check the
        # version itself — confirm callers gate it on a backend feature flag.
        if not fields:
            return "", ()
        columns = [
            f"{self.quote_name(field.model.model_options.db_table)}.{self.quote_name(field.column)}"
            for field in fields
        ]
        return "RETURNING {}".format(", ".join(columns)), ()

    def validate_autopk_value(self, value: int) -> int:
        """
        Reject 0 as an explicit auto primary key value unless the backend
        features say it is allowed; otherwise return ``value`` unchanged.
        """
        # Zero in AUTO_INCREMENT field does not work without the
        # NO_AUTO_VALUE_ON_ZERO SQL mode.
        if value == 0 and not self.connection.features.allows_auto_pk_0:
            raise ValueError(
                "The database backend does not accept 0 as a value for PrimaryKeyField."
            )
        return value

    def adapt_datetimefield_value(
        self, value: datetime.datetime | Any | None
    ) -> str | Any | None:
        """
        Convert a datetime to the naive string form stored by MySQL.

        Aware datetimes are first made naive in the connection's time zone.
        ``None`` and query expressions are returned unchanged.
        """
        if value is None:
            return None

        # Expression values are adapted by the database.
        if hasattr(value, "resolve_expression"):
            return value

        # MySQL doesn't support tz-aware datetimes
        if timezone.is_aware(value):
            value = timezone.make_naive(value, self.connection.timezone)
        return str(value)

    def adapt_timefield_value(
        self, value: datetime.time | Any | None
    ) -> str | Any | None:
        """
        Convert a time to an ISO string with microsecond precision.

        Raises ValueError for timezone-aware times. ``None`` and query
        expressions are returned unchanged.
        """
        if value is None:
            return None

        # Expression values are adapted by the database.
        if hasattr(value, "resolve_expression"):
            return value

        # MySQL doesn't support tz-aware times
        if timezone.is_aware(value):  # type: ignore[arg-type]
            raise ValueError("MySQL backend does not support timezone-aware times.")

        return value.isoformat(timespec="microseconds")

    def max_name_length(self) -> int:
        """Return the maximum identifier length (64 characters on MySQL)."""
        return 64

    def pk_default_value(self) -> str:
        """Return the SQL used in INSERTs for a primary key left to its default."""
        return "NULL"

    def bulk_insert_sql(
        self, fields: list[Any], placeholder_rows: list[list[str]]
    ) -> str:
        """
        Render ``placeholder_rows`` as a multi-row ``VALUES (...), (...)``
        clause; ``fields`` is not used.
        """
        placeholder_rows_sql = (", ".join(row) for row in placeholder_rows)
        values_sql = ", ".join(f"({sql})" for sql in placeholder_rows_sql)
        return "VALUES " + values_sql

    def combine_expression(self, connector: str, sub_expressions: list[str]) -> str:
        """
        Combine ``sub_expressions`` with ``connector`` using MySQL syntax:
        ``^`` is power, ``#`` is bitwise XOR and ``>>`` is emulated with
        division. Other connectors use the base implementation.
        """
        if connector == "^":
            return "POW({})".format(",".join(sub_expressions))
        # Convert the result to a signed integer since MySQL's binary operators
        # return an unsigned integer.
        elif connector in ("&", "|", "<<", "#"):
            connector = "^" if connector == "#" else connector
            return f"CONVERT({connector.join(sub_expressions)}, SIGNED)"
        elif connector == ">>":
            lhs, rhs = sub_expressions
            return f"FLOOR({lhs} / POW(2, {rhs}))"
        return super().combine_expression(connector, sub_expressions)

    def get_db_converters(self, expression: Any) -> list[Any]:
        """
        Add MySQL-specific converters for boolean, datetime and UUID outputs
        to the base converters for ``expression``.
        """
        converters = super().get_db_converters(expression)
        internal_type = expression.output_field.get_internal_type()
        if internal_type == "BooleanField":
            converters.append(self.convert_booleanfield_value)
        elif internal_type == "DateTimeField":
            converters.append(self.convert_datetimefield_value)
        elif internal_type == "UUIDField":
            converters.append(self.convert_uuidfield_value)
        return converters

    def convert_booleanfield_value(
        self, value: Any, expression: Any, connection: BaseDatabaseWrapper
    ) -> Any:
        """Map integer 0/1 from the database to bool; leave other values as-is."""
        if value in (0, 1):
            value = bool(value)
        return value

    def convert_datetimefield_value(
        self, value: Any, expression: Any, connection: BaseDatabaseWrapper
    ) -> datetime.datetime | None:
        """Make naive datetimes from the database aware in the connection zone."""
        if value is not None:
            value = timezone.make_aware(value, self.connection.timezone)
        return value

    def convert_uuidfield_value(
        self, value: Any, expression: Any, connection: BaseDatabaseWrapper
    ) -> uuid.UUID | None:
        """Parse a stored UUID value into ``uuid.UUID``; ``None`` passes through."""
        if value is not None:
            value = uuid.UUID(value)
        return value

    def binary_placeholder_sql(self, value: Any) -> str:
        """
        Return the placeholder for a binary value: the ``_binary`` introducer
        for concrete values, a plain ``%s`` for NULL and SQL expressions
        (anything with ``as_sql``).
        """
        return (
            "_binary %s" if value is not None and not hasattr(value, "as_sql") else "%s"
        )

    def subtract_temporals(
        self,
        internal_type: str,
        lhs: tuple[str, list[Any] | tuple[Any, ...]],
        rhs: tuple[str, list[Any] | tuple[Any, ...]],
    ) -> tuple[str, tuple[Any, ...]]:
        """
        Return ``(sql, params)`` computing ``lhs - rhs`` in microseconds.

        TimeField operands use TIME_TO_SEC arithmetic; everything else uses
        TIMESTAMPDIFF.
        """
        lhs_sql, lhs_params = lhs
        rhs_sql, rhs_params = rhs
        if internal_type == "TimeField":
            if self.connection.mysql_is_mariadb:
                # MariaDB includes the microsecond component in TIME_TO_SEC as
                # a decimal. MySQL returns an integer without microseconds.
                return (
                    f"CAST((TIME_TO_SEC({lhs_sql}) - TIME_TO_SEC({rhs_sql})) "
                    "* 1000000 AS SIGNED)"
                ), (
                    *lhs_params,
                    *rhs_params,
                )
            # Each operand appears twice (TIME_TO_SEC and MICROSECOND), so its
            # params are repeated in placeholder order: lhs, lhs, rhs, rhs.
            return (
                f"((TIME_TO_SEC({lhs_sql}) * 1000000 + MICROSECOND({lhs_sql})) -"
                f" (TIME_TO_SEC({rhs_sql}) * 1000000 + MICROSECOND({rhs_sql})))"
            ), tuple(lhs_params) * 2 + tuple(rhs_params) * 2
        # TIMESTAMPDIFF(unit, start, end) computes end - start, so rhs comes
        # first in both the SQL and the params.
        params = (*rhs_params, *lhs_params)
        return f"TIMESTAMPDIFF(MICROSECOND, {rhs_sql}, {lhs_sql})", params

    def explain_query_prefix(self, format: str | None = None, **options: Any) -> str:
        """
        Build the EXPLAIN prefix, defaulting to the TREE format when supported
        and handling the MySQL / MariaDB differences for ANALYZE.
        """
        # Alias MySQL's TRADITIONAL to TEXT for consistency with other backends.
        if format and format.upper() == "TEXT":
            format = "TRADITIONAL"
        elif (
            not format and "TREE" in self.connection.features.supported_explain_formats
        ):
            # Use TREE by default (if supported) as it's more informative.
            format = "TREE"
        analyze = options.pop("analyze", False)
        prefix = super().explain_query_prefix(format, **options)
        if analyze and self.connection.features.supports_explain_analyze:
            # MariaDB uses ANALYZE instead of EXPLAIN ANALYZE.
            prefix = (
                "ANALYZE" if self.connection.mysql_is_mariadb else prefix + " ANALYZE"
            )
        if format and not (analyze and not self.connection.mysql_is_mariadb):
            # Only MariaDB supports the analyze option with formats.
            prefix += f" FORMAT={format}"
        return prefix

    def regex_lookup(self, lookup_type: str) -> str:
        """
        Return the SQL template for the ``regex`` (case-sensitive) or
        ``iregex`` (case-insensitive) lookup.
        """
        # REGEXP_LIKE doesn't exist in MariaDB.
        if self.connection.mysql_is_mariadb:
            if lookup_type == "regex":
                return "%s REGEXP BINARY %s"
            return "%s REGEXP %s"

        # REGEXP_LIKE match type: "c" = case-sensitive, "i" = insensitive.
        match_option = "c" if lookup_type == "regex" else "i"
        return f"REGEXP_LIKE(%s, %s, '{match_option}')"

    def insert_statement(self, on_conflict: Any = None) -> str:
        """Use INSERT IGNORE when conflicts should be ignored."""
        if on_conflict == OnConflict.IGNORE:
            return "INSERT IGNORE INTO"
        return super().insert_statement(on_conflict=on_conflict)

    def lookup_cast(self, lookup_type: str, internal_type: str | None = None) -> str:
        """
        Return the template applied to the left-hand side of a lookup.

        JSONField values are wrapped in JSON_UNQUOTE for text-style lookups
        (and for every lookup on MariaDB) so they compare as plain strings.
        """
        lookup = "%s"
        if internal_type == "JSONField":
            if self.connection.mysql_is_mariadb or lookup_type in (
                "iexact",
                "contains",
                "icontains",
                "startswith",
                "istartswith",
                "endswith",
                "iendswith",
                "regex",
                "iregex",
            ):
                lookup = "JSON_UNQUOTE(%s)"
        return lookup

    def conditional_expression_supported_in_where_clause(self, expression: Any) -> bool:
        """
        Return whether ``expression`` can be used directly as a WHERE
        condition rather than being compared to a boolean explicitly.
        """
        # MySQL ignores indexes with boolean fields unless they're compared
        # directly to a boolean value.
        if isinstance(expression, Exists | Lookup):
            return True
        if isinstance(expression, ExpressionWrapper) and expression.conditional:
            return self.conditional_expression_supported_in_where_clause(
                expression.expression
            )
        if getattr(expression, "conditional", False):
            return False
        return super().conditional_expression_supported_in_where_clause(expression)

    def on_conflict_suffix_sql(
        self,
        fields: list[Any],
        on_conflict: Any,
        update_fields: list[Any],
        unique_fields: list[Any],
    ) -> str:
        """
        Return the upsert suffix for an INSERT.

        For ``OnConflict.UPDATE`` this is ``ON DUPLICATE KEY UPDATE`` that
        assigns each of ``update_fields`` from the incoming row. Other modes
        use the base implementation.
        """
        if on_conflict == OnConflict.UPDATE:
            conflict_suffix_sql = "ON DUPLICATE KEY UPDATE %(fields)s"
            # The use of VALUES() is deprecated in MySQL 8.0.20+. Instead, use
            # aliases for the new row and its columns available in MySQL
            # 8.0.19+.
            if not self.connection.mysql_is_mariadb:
                if self.connection.mysql_version >= (8, 0, 19):
                    conflict_suffix_sql = f"AS new {conflict_suffix_sql}"
                    field_sql = "%(field)s = new.%(field)s"
                else:
                    field_sql = "%(field)s = VALUES(%(field)s)"
            # Use VALUE() on MariaDB.
            else:
                field_sql = "%(field)s = VALUE(%(field)s)"

            fields_str = ", ".join(
                [
                    field_sql % {"field": field}
                    for field in map(self.quote_name, update_fields)
                ]
            )
            return conflict_suffix_sql % {"fields": fields_str}
        return super().on_conflict_suffix_sql(
            fields,
            on_conflict,
            update_fields,
            unique_fields,
        )