1from __future__ import annotations
2
3import datetime
4import uuid
5from collections.abc import Iterable
6from functools import cached_property
7from typing import TYPE_CHECKING, Any
8
9from plain.models.backends.base.operations import BaseDatabaseOperations
10from plain.models.backends.utils import CursorWrapper, split_tzname_delta
11from plain.models.constants import OnConflict
12from plain.models.expressions import Exists, ExpressionWrapper, ResolvableExpression
13from plain.models.lookups import Lookup
14from plain.utils import timezone
15from plain.utils.encoding import force_str
16from plain.utils.regex_helper import _lazy_re_compile
17
18if TYPE_CHECKING:
19 from plain.models.backends.base.base import BaseDatabaseWrapper
20 from plain.models.backends.mysql.base import MySQLDatabaseWrapper
21 from plain.models.fields import Field
22 from plain.models.sql.compiler import SQLCompiler
23 from plain.models.sql.query import Query
24
25
26class DatabaseOperations(BaseDatabaseOperations):
27 # Type checker hint: connection is always MySQLDatabaseWrapper in this class
28 connection: MySQLDatabaseWrapper
29
@cached_property
def compilers(self) -> dict[type[Query], type[SQLCompiler]]:
    """Return the query-class -> compiler-class table for the MySQL backend.

    The table is built on first access and then cached on the instance by
    ``cached_property``. The imports are function-local; presumably this
    avoids an import cycle at module load time -- confirm before hoisting.
    """
    from plain.models.backends.mysql import compiler as mysql_compiler
    from plain.models.sql.query import Query
    from plain.models.sql.subqueries import (
        AggregateQuery,
        DeleteQuery,
        InsertQuery,
        UpdateQuery,
    )

    compiler_for: dict[type[Query], type[SQLCompiler]] = {}
    compiler_for[Query] = mysql_compiler.SQLCompiler
    compiler_for[DeleteQuery] = mysql_compiler.SQLDeleteCompiler
    compiler_for[UpdateQuery] = mysql_compiler.SQLUpdateCompiler
    compiler_for[InsertQuery] = mysql_compiler.SQLInsertCompiler
    compiler_for[AggregateQuery] = mysql_compiler.SQLAggregateCompiler
    return compiler_for
48
# MySQL stores positive fields as UNSIGNED ints, so the upper bounds below are
# the unsigned 16-, 32- and 64-bit maxima (2**16 - 1, 2**32 - 1, 2**64 - 1).
integer_field_ranges = {
    **BaseDatabaseOperations.integer_field_ranges,
    "PositiveSmallIntegerField": (0, 65535),
    "PositiveIntegerField": (0, 4294967295),
    "PositiveBigIntegerField": (0, 18446744073709551615),
}
# Keyed by field internal-type name; each value is the MySQL type to cast to.
# NOTE(review): the %(...)s placeholders are presumably filled in from the
# field's attributes by whoever consumes this table -- confirm against the
# base operations class.
cast_data_types = {
    "PrimaryKeyField": "signed integer",
    "CharField": "char(%(max_length)s)",
    "DecimalField": "decimal(%(max_digits)s, %(decimal_places)s)",
    "TextField": "char",
    "IntegerField": "signed integer",
    "BigIntegerField": "signed integer",
    "SmallIntegerField": "signed integer",
    "PositiveBigIntegerField": "unsigned integer",
    "PositiveIntegerField": "unsigned integer",
    "PositiveSmallIntegerField": "unsigned integer",
    "DurationField": "signed integer",
}
# Presumably the cast target for a CharField without max_length -- confirm.
cast_char_field_without_max_length = "char"
# Presumably the keyword the base class puts in front of a query to explain
# it -- confirm against BaseDatabaseOperations.explain_query_prefix().
explain_prefix = "EXPLAIN"

# EXTRACT format cannot be passed in parameters, so date_extract_sql()
# validates the upper-cased unit name against this pattern before
# interpolating it into the SQL text.
_extract_format_re = _lazy_re_compile(r"[A-Z_]+")
74
def date_extract_sql(
    self, lookup_type: str, sql: str, params: list[Any] | tuple[Any, ...]
) -> tuple[str, list[Any] | tuple[Any, ...]]:
    """Return SQL (and its params) that extracts ``lookup_type`` from ``sql``.

    ``week_day``, ``iso_week_day``, ``week`` and ``iso_year`` map to dedicated
    MySQL functions; any other lookup type is upper-cased and used as the
    unit of an ``EXTRACT(<unit> FROM ...)`` expression. ``params`` is returned
    unchanged in every case.

    Raises:
        ValueError: if the upper-cased lookup type contains anything other
            than ``A-Z`` and ``_``. The unit is interpolated into the SQL
            text (it cannot be sent as a parameter), so it must be validated.
    """
    # https://dev.mysql.com/doc/mysql/en/date-and-time-functions.html
    if lookup_type == "week_day":
        # DAYOFWEEK() returns an integer, 1-7, Sunday=1.
        return f"DAYOFWEEK({sql})", params
    elif lookup_type == "iso_week_day":
        # WEEKDAY() returns an integer, 0-6, Monday=0.
        return f"WEEKDAY({sql}) + 1", params
    elif lookup_type == "week":
        # Override the value of default_week_format for consistency with
        # other database backends.
        # Mode 3: Monday, 1-53, with 4 or more days this year.
        return f"WEEK({sql}, 3)", params
    elif lookup_type == "iso_year":
        # Get the year part from the YEARWEEK function, which returns a
        # number as year * 100 + week.
        return f"TRUNCATE(YEARWEEK({sql}, 3), -2) / 100", params
    else:
        # Every remaining lookup type is passed to EXTRACT() as a unit name.
        lookup_type = lookup_type.upper()
        if not self._extract_format_re.fullmatch(lookup_type):
            raise ValueError(f"Invalid lookup type: {lookup_type!r}")
        return f"EXTRACT({lookup_type} FROM {sql})", params
100
def date_trunc_sql(
    self,
    lookup_type: str,
    sql: str,
    params: list[Any] | tuple[Any, ...],
    tzname: str | None = None,
) -> tuple[str, list[Any] | tuple[Any, ...]]:
    """Return SQL (and params) truncating a date to ``lookup_type``.

    The expression is first passed through ``_convert_sql_to_tz``. ``year``
    and ``month`` go through ``DATE_FORMAT`` with the format sent as an extra
    parameter; ``quarter`` and ``week`` embed the expression twice, so its
    parameters are duplicated; anything else falls back to ``DATE(...)``.
    """
    sql, params = self._convert_sql_to_tz(sql, params, tzname)

    date_format = {"year": "%Y-01-01", "month": "%Y-%m-01"}.get(lookup_type)
    if date_format is not None:
        return f"CAST(DATE_FORMAT({sql}, %s) AS DATE)", (*params, date_format)

    if lookup_type == "quarter":
        quarter_start = (
            f"MAKEDATE(YEAR({sql}), 1) + "
            f"INTERVAL QUARTER({sql}) QUARTER - INTERVAL 1 QUARTER"
        )
        return quarter_start, (*params, *params)

    if lookup_type == "week":
        week_start = f"DATE_SUB({sql}, INTERVAL WEEKDAY({sql}) DAY)"
        return week_start, (*params, *params)

    return f"DATE({sql})", params
126
def _prepare_tzname_delta(self, tzname: str) -> str:
    """Return ``<sign><offset>`` when ``tzname`` carries an offset, else the name.

    The three parts come from ``split_tzname_delta``; when its offset part is
    falsy the name part it returned is used as-is.
    """
    name, sign, offset = split_tzname_delta(tzname)
    if offset:
        return f"{sign}{offset}"
    return name
130
def _convert_sql_to_tz(
    self, sql: str, params: list[Any] | tuple[Any, ...], tzname: str | None
) -> tuple[str, list[Any] | tuple[Any, ...]]:
    """Wrap ``sql`` in ``CONVERT_TZ`` when ``tzname`` differs from the connection's.

    With no ``tzname``, or one equal to ``connection.timezone_name``, the
    inputs are returned untouched. Otherwise the source and target zones are
    appended to the parameters (they are not interpolated into the SQL).
    """
    if not tzname or self.connection.timezone_name == tzname:
        return sql, params

    source_tz = self.connection.timezone_name
    target_tz = self._prepare_tzname_delta(tzname)
    return f"CONVERT_TZ({sql}, %s, %s)", (*params, source_tz, target_tz)
141
def datetime_cast_date_sql(
    self, sql: str, params: list[Any] | tuple[Any, ...], tzname: str | None
) -> tuple[str, list[Any] | tuple[Any, ...]]:
    """Return ``DATE(...)`` around the timezone-converted expression."""
    converted_sql, converted_params = self._convert_sql_to_tz(sql, params, tzname)
    return f"DATE({converted_sql})", converted_params
147
def datetime_cast_time_sql(
    self, sql: str, params: list[Any] | tuple[Any, ...], tzname: str | None
) -> tuple[str, list[Any] | tuple[Any, ...]]:
    """Return ``TIME(...)`` around the timezone-converted expression."""
    converted_sql, converted_params = self._convert_sql_to_tz(sql, params, tzname)
    return f"TIME({converted_sql})", converted_params
153
def datetime_extract_sql(
    self,
    lookup_type: str,
    sql: str,
    params: list[Any] | tuple[Any, ...],
    tzname: str | None,
) -> tuple[str, list[Any] | tuple[Any, ...]]:
    """Timezone-convert the expression, then delegate to ``date_extract_sql``."""
    converted_sql, converted_params = self._convert_sql_to_tz(sql, params, tzname)
    return self.date_extract_sql(lookup_type, converted_sql, converted_params)
163
def datetime_trunc_sql(
    self,
    lookup_type: str,
    sql: str,
    params: list[Any] | tuple[Any, ...],
    tzname: str | None,
) -> tuple[str, list[Any] | tuple[Any, ...]]:
    """Return SQL (and params) truncating a datetime to ``lookup_type``.

    ``quarter`` and ``week`` embed the (timezone-converted) expression twice,
    so its parameters are duplicated. For ``year`` .. ``second`` a
    ``DATE_FORMAT`` string is assembled that keeps the real components up to
    the requested unit and hard-codes the rest. Any other lookup type returns
    the converted expression unchanged.
    """
    sql, params = self._convert_sql_to_tz(sql, params, tzname)

    if lookup_type == "quarter":
        quarter_sql = (
            f"CAST(DATE_FORMAT(MAKEDATE(YEAR({sql}), 1) + "
            f"INTERVAL QUARTER({sql}) QUARTER - "
            f"INTERVAL 1 QUARTER, %s) AS DATETIME)"
        )
        return quarter_sql, (*params, *params, "%Y-%m-01 00:00:00")

    if lookup_type == "week":
        week_sql = (
            f"CAST(DATE_FORMAT("
            f"DATE_SUB({sql}, INTERVAL WEEKDAY({sql}) DAY), %s) AS DATETIME)"
        )
        return week_sql, (*params, *params, "%Y-%m-%d 00:00:00")

    units = ["year", "month", "day", "hour", "minute", "second"]
    if lookup_type not in units:
        return sql, params

    # Format pieces, one per unit: the live specifier and its fixed stand-in.
    live_parts = ("%Y-", "%m", "-%d", " %H:", "%i", ":%s")
    fixed_parts = ("0000-", "01", "-01", " 00:", "00", ":00")
    keep = units.index(lookup_type) + 1
    format_str = "".join(live_parts[:keep] + fixed_parts[keep:])
    return f"CAST(DATE_FORMAT({sql}, %s) AS DATETIME)", (*params, format_str)
194
def time_trunc_sql(
    self,
    lookup_type: str,
    sql: str,
    params: list[Any] | tuple[Any, ...],
    tzname: str | None = None,
) -> tuple[str, list[Any] | tuple[Any, ...]]:
    """Return SQL (and params) truncating a time to ``lookup_type``.

    ``hour``, ``minute`` and ``second`` use ``DATE_FORMAT`` with the format
    sent as an extra parameter; any other lookup type yields ``TIME(...)``.
    """
    sql, params = self._convert_sql_to_tz(sql, params, tzname)
    time_formats = {
        "hour": "%H:00:00",
        "minute": "%H:%i:00",
        "second": "%H:%i:%s",
    }
    format_str = time_formats.get(lookup_type)
    if format_str is None:
        return f"TIME({sql})", params
    return f"CAST(DATE_FORMAT({sql}, %s) AS TIME)", (*params, format_str)
213
def fetch_returned_insert_rows(self, cursor: CursorWrapper) -> list[Any]:
    """
    Given a cursor object that has just performed an INSERT...RETURNING
    statement into a table, return every row it produced, exactly as
    ``cursor.fetchall()`` hands them back.
    """
    returned_rows = cursor.fetchall()
    return returned_rows
220
def format_for_duration_arithmetic(self, sql: str) -> str:
    """Wrap ``sql`` as a MySQL ``INTERVAL ... MICROSECOND`` expression."""
    return "INTERVAL {} MICROSECOND".format(sql)
223
def force_no_ordering(
    self,
) -> list[tuple[Any, tuple[str, tuple[Any, ...], bool]]]:
    """
    Return a single "ORDER BY NULL" ordering entry.

    "ORDER BY NULL" prevents MySQL from implicitly ordering by grouped
    columns. If no ordering would otherwise be applied, we don't want any
    implicit sorting going on.
    """
    null_ordering = ("NULL", (), False)
    return [(None, null_ordering)]
233
def adapt_decimalfield_value(
    self,
    value: Any,
    max_digits: int | None = None,
    decimal_places: int | None = None,
) -> Any:
    """Return ``value`` unchanged; ``max_digits``/``decimal_places`` are ignored."""
    return value
241
def last_executed_query(
    self, cursor: CursorWrapper, sql: str, params: Any
) -> str | None:
    """Return the cursor's ``_executed`` attribute passed through ``force_str``.

    ``sql`` and ``params`` are not used. When the cursor has no ``_executed``
    attribute, ``None`` is what gets handed to ``force_str``.
    """
    # With MySQLdb, cursor objects have an (undocumented) "_executed"
    # attribute where the exact query sent to the database is saved.
    # See MySQLdb/cursors.py in the source distribution.
    # MySQLdb returns string, PyMySQL bytes.
    executed = getattr(cursor, "_executed", None)
    return force_str(executed, errors="replace")
250
def no_limit_value(self) -> int:
    """Return the LIMIT value that stands in for "no limit" on MySQL."""
    # The largest unsigned 64-bit integer (18446744073709551615), as
    # recommended by the MySQL documentation.
    return 2**64 - 1
254
def quote_name(self, name: str) -> str:
    """Wrap ``name`` in backticks unless it already starts and ends with one."""
    already_quoted = name.startswith("`") and name.endswith("`")
    # Quoting once is enough.
    return name if already_quoted else f"`{name}`"
259
def return_insert_columns(self, fields: list[Any]) -> tuple[str, tuple[Any, ...]]:
    """Build the ``RETURNING <table>.<column>, ...`` clause for ``fields``.

    Returns ``("", ())`` when ``fields`` is empty; the params tuple is always
    empty.
    """
    # MySQL and MariaDB < 10.5.0 don't support an INSERT...RETURNING
    # statement.
    if not fields:
        return "", ()

    qualified_columns = []
    for field in fields:
        table = self.quote_name(field.model.model_options.db_table)
        column = self.quote_name(field.column)
        qualified_columns.append(f"{table}.{column}")
    return "RETURNING " + ", ".join(qualified_columns), ()
270
def validate_autopk_value(self, value: int) -> int:
    """Return ``value``, rejecting 0 when the connection does not allow it.

    Raises:
        ValueError: if ``value`` is 0 and
            ``connection.features.allows_auto_pk_0`` is falsy.
    """
    if value != 0 or self.connection.features.allows_auto_pk_0:
        return value
    # Zero in AUTO_INCREMENT field does not work without the
    # NO_AUTO_VALUE_ON_ZERO SQL mode.
    raise ValueError(
        "The database backend does not accept 0 as a value for PrimaryKeyField."
    )
279
def adapt_datetimefield_value(
    self, value: datetime.datetime | Any | None
) -> str | Any | None:
    """Convert a datetime into the string form sent to MySQL.

    ``None`` and ``ResolvableExpression`` instances are passed through
    untouched. Aware datetimes are made naive in ``connection.timezone``
    first; the result is ``str(value)``.
    """
    # None stays None; expression values are adapted by the database.
    if value is None or isinstance(value, ResolvableExpression):
        return value

    # MySQL doesn't support tz-aware datetimes
    naive_value = (
        timezone.make_naive(value, self.connection.timezone)
        if timezone.is_aware(value)
        else value
    )
    return str(naive_value)
294
def adapt_timefield_value(
    self, value: datetime.time | Any | None
) -> str | Any | None:
    """Convert a time into the ISO string (microsecond precision) sent to MySQL.

    ``None`` and ``ResolvableExpression`` instances are passed through
    untouched.

    Raises:
        ValueError: if ``timezone.is_aware(value)`` is true.
    """
    # None stays None; expression values are adapted by the database.
    if value is None or isinstance(value, ResolvableExpression):
        return value

    # MySQL doesn't support tz-aware times
    if timezone.is_aware(value):
        raise ValueError("MySQL backend does not support timezone-aware times.")

    iso_text = value.isoformat(timespec="microseconds")
    return iso_text
310
def max_name_length(self) -> int:
    """Return the maximum identifier length this backend reports (64)."""
    limit = 64
    return limit
313
def pk_default_value(self) -> str:
    """Return the SQL literal used as the default primary-key value."""
    default_literal = "NULL"
    return default_literal
316
def bulk_insert_sql(
    self, fields: list[Any], placeholder_rows: list[list[str]]
) -> str:
    """Render ``VALUES (..), (..)`` from rows of placeholders.

    Each row's placeholders are comma-joined and parenthesised; the rows are
    then comma-joined. ``fields`` is not used.
    """
    row_groups = ["(" + ", ".join(row) + ")" for row in placeholder_rows]
    return "VALUES " + ", ".join(row_groups)
323
def combine_expression(self, connector: str, sub_expressions: list[str]) -> str:
    """Join ``sub_expressions`` with ``connector`` using MySQL-specific forms.

    ``^`` becomes ``POW(...)``, ``>>`` (exactly two operands) becomes a
    ``FLOOR`` division by a power of two, and the bitwise connectors are
    wrapped in ``CONVERT(..., SIGNED)``, with ``#`` rendered as MySQL's ``^``.
    Every other connector is delegated to the parent class.
    """
    if connector == "^":
        operands = ",".join(sub_expressions)
        return "POW({})".format(operands)

    if connector == ">>":
        lhs, rhs = sub_expressions
        return f"FLOOR({lhs} / POW(2, {rhs}))"

    if connector in ("&", "|", "<<", "#"):
        # Convert the result to a signed integer since MySQL's binary
        # operators return an unsigned integer.
        operator = "^" if connector == "#" else connector
        return "CONVERT({}, SIGNED)".format(operator.join(sub_expressions))

    return super().combine_expression(connector, sub_expressions)
336
def get_db_converters(self, expression: Any) -> list[Any]:
    """Return the parent's converters plus a MySQL one for some field types.

    A converter is appended when the expression's output field reports an
    internal type of ``BooleanField``, ``DateTimeField`` or ``UUIDField``.
    """
    converters = super().get_db_converters(expression)
    internal_type = expression.output_field.get_internal_type()
    converter_names = {
        "BooleanField": "convert_booleanfield_value",
        "DateTimeField": "convert_datetimefield_value",
        "UUIDField": "convert_uuidfield_value",
    }
    method_name = converter_names.get(internal_type)
    if method_name is not None:
        converters.append(getattr(self, method_name))
    return converters
347
def convert_booleanfield_value(
    self, value: Any, expression: Any, connection: BaseDatabaseWrapper
) -> Any:
    """Turn a 0/1 value into ``False``/``True``; return anything else as-is."""
    return bool(value) if value in (0, 1) else value
354
def convert_datetimefield_value(
    self, value: Any, expression: Any, connection: BaseDatabaseWrapper
) -> datetime.datetime | None:
    """Make a non-``None`` value aware in ``self.connection.timezone``.

    Note that the ``connection`` argument is not consulted; the timezone
    comes from ``self.connection``.
    """
    if value is None:
        return value
    return timezone.make_aware(value, self.connection.timezone)
361
def convert_uuidfield_value(
    self, value: Any, expression: Any, connection: BaseDatabaseWrapper
) -> uuid.UUID | None:
    """Build a ``uuid.UUID`` from a non-``None`` value; ``None`` stays ``None``."""
    if value is None:
        return None
    return uuid.UUID(value)
368
def binary_placeholder_sql(self, value: Any) -> str:
    """Return the placeholder to use for a binary value.

    ``None`` and objects exposing ``as_sql`` get a plain ``%s``; every other
    value gets the ``_binary`` introducer in front of it.
    """
    if value is None or hasattr(value, "as_sql"):
        return "%s"
    return "_binary %s"
373
def subtract_temporals(
    self,
    internal_type: str,
    lhs: tuple[str, list[Any] | tuple[Any, ...]],
    rhs: tuple[str, list[Any] | tuple[Any, ...]],
) -> tuple[str, tuple[Any, ...]]:
    """Return SQL (and params) for ``lhs - rhs`` in microseconds.

    Non-time types use ``TIMESTAMPDIFF``, whose arguments (and therefore
    params) are in rhs-then-lhs order. ``TimeField`` uses ``TIME_TO_SEC``;
    on MySQL each operand appears twice in the SQL, so its params are
    doubled.
    """
    lhs_sql, lhs_params = lhs
    rhs_sql, rhs_params = rhs

    if internal_type != "TimeField":
        diff_params = (*rhs_params, *lhs_params)
        return f"TIMESTAMPDIFF(MICROSECOND, {rhs_sql}, {lhs_sql})", diff_params

    if self.connection.mysql_is_mariadb:
        # MariaDB includes the microsecond component in TIME_TO_SEC as
        # a decimal. MySQL returns an integer without microseconds.
        mariadb_sql = (
            f"CAST((TIME_TO_SEC({lhs_sql}) - TIME_TO_SEC({rhs_sql})) "
            "* 1000000 AS SIGNED)"
        )
        return mariadb_sql, (*lhs_params, *rhs_params)

    mysql_sql = (
        f"((TIME_TO_SEC({lhs_sql}) * 1000000 + MICROSECOND({lhs_sql})) -"
        f" (TIME_TO_SEC({rhs_sql}) * 1000000 + MICROSECOND({rhs_sql})))"
    )
    return mysql_sql, tuple(lhs_params) * 2 + tuple(rhs_params) * 2
399
def explain_query_prefix(self, format: str | None = None, **options: Any) -> str:
    """Build the EXPLAIN / ANALYZE prefix for a query.

    ``format`` "TEXT" (any case) is sent as ``TRADITIONAL``; with no format,
    ``TREE`` is used when the connection lists it as supported. The
    ``analyze`` option is consumed here and the remaining options go to the
    parent implementation. ``FORMAT=`` is left out whenever ``analyze`` was
    requested on a non-MariaDB connection -- including the case where
    ANALYZE itself was not emitted because it is unsupported.
    """
    if format:
        # Alias MySQL's TRADITIONAL to TEXT for consistency with other backends.
        if format.upper() == "TEXT":
            format = "TRADITIONAL"
    elif "TREE" in self.connection.features.supported_explain_formats:
        # Use TREE by default (if supported) as it's more informative.
        format = "TREE"

    analyze = options.pop("analyze", False)
    prefix = super().explain_query_prefix(format, **options)

    if analyze and self.connection.features.supports_explain_analyze:
        # MariaDB uses ANALYZE instead of EXPLAIN ANALYZE.
        if self.connection.mysql_is_mariadb:
            prefix = "ANALYZE"
        else:
            prefix = prefix + " ANALYZE"

    # Only MariaDB supports the analyze option with formats.
    if format and (not analyze or self.connection.mysql_is_mariadb):
        prefix += f" FORMAT={format}"
    return prefix
420
def regex_lookup(self, lookup_type: str) -> str:
    """Return the SQL template for a regex lookup.

    ``lookup_type`` ``"regex"`` selects the case-sensitive form; any other
    value selects the case-insensitive one.
    """
    case_sensitive = lookup_type == "regex"

    if not self.connection.mysql_is_mariadb:
        match_option = "c" if case_sensitive else "i"
        return f"REGEXP_LIKE(%s, %s, '{match_option}')"

    # REGEXP_LIKE doesn't exist in MariaDB.
    return "%s REGEXP BINARY %s" if case_sensitive else "%s REGEXP %s"
430
def insert_statement(self, on_conflict: Any = None) -> str:
    """Return ``INSERT IGNORE INTO`` for ``OnConflict.IGNORE``.

    Any other ``on_conflict`` value is delegated to the parent class.
    """
    ignore_conflicts = on_conflict == OnConflict.IGNORE
    if ignore_conflicts:
        return "INSERT IGNORE INTO"
    return super().insert_statement(on_conflict=on_conflict)
435
def lookup_cast(self, lookup_type: str, internal_type: str | None = None) -> str:
    """Return the template that wraps the left-hand side of a lookup.

    Only ``JSONField`` is ever wrapped: in ``JSON_UNQUOTE(...)`` on MariaDB
    for every lookup, and elsewhere for the text-pattern lookups listed
    below. Everything else gets a bare ``%s``.
    """
    if internal_type != "JSONField":
        return "%s"

    text_lookups = (
        "iexact",
        "contains",
        "icontains",
        "startswith",
        "istartswith",
        "endswith",
        "iendswith",
        "regex",
        "iregex",
    )
    if self.connection.mysql_is_mariadb or lookup_type in text_lookups:
        return "JSON_UNQUOTE(%s)"
    return "%s"
452
def conditional_expression_supported_in_where_clause(self, expression: Any) -> bool:
    """Tell whether ``expression`` may appear bare in a WHERE clause.

    ``Exists`` and ``Lookup`` instances are accepted. A conditional
    ``ExpressionWrapper`` is judged by the expression it wraps. Any other
    object with a truthy ``conditional`` attribute is rejected; the rest is
    left to the parent class.
    """
    # MySQL ignores indexes with boolean fields unless they're compared
    # directly to a boolean value.
    if isinstance(expression, (Exists, Lookup)):
        return True

    is_conditional_wrapper = (
        isinstance(expression, ExpressionWrapper) and expression.conditional
    )
    if is_conditional_wrapper:
        wrapped = expression.expression
        return self.conditional_expression_supported_in_where_clause(wrapped)

    if getattr(expression, "conditional", False):
        return False
    return super().conditional_expression_supported_in_where_clause(expression)
465
def on_conflict_suffix_sql(
    self,
    fields: list[Field],
    on_conflict: Any,
    update_fields: Iterable[str],
    unique_fields: Iterable[str],
) -> str:
    """Return the ``ON DUPLICATE KEY UPDATE`` suffix for an upsert.

    Only ``OnConflict.UPDATE`` is handled here; every other value is
    delegated to the parent class. Each name in ``update_fields`` is quoted
    and assigned from the incoming row, using the syntax that matches the
    server flavour and version.
    """
    wants_update = on_conflict == OnConflict.UPDATE
    if not wants_update:
        return super().on_conflict_suffix_sql(
            fields,
            on_conflict,
            update_fields,
            unique_fields,
        )

    suffix_template = "ON DUPLICATE KEY UPDATE %(fields)s"
    if self.connection.mysql_is_mariadb:
        # Use VALUE() on MariaDB.
        assignment_template = "%(field)s = VALUE(%(field)s)"
    elif self.connection.mysql_version >= (8, 0, 19):
        # The use of VALUES() is deprecated in MySQL 8.0.20+. Instead, use
        # aliases for the new row and its columns available in MySQL
        # 8.0.19+.
        suffix_template = f"AS new {suffix_template}"
        assignment_template = "%(field)s = new.%(field)s"
    else:
        assignment_template = "%(field)s = VALUES(%(field)s)"

    assignments = [
        assignment_template % {"field": self.quote_name(field_name)}
        for field_name in update_fields
    ]
    return suffix_template % {"fields": ", ".join(assignments)}