1"""
2MySQL database backend for Plain.
3
4Requires mysqlclient: https://pypi.org/project/mysqlclient/
5"""
6
7from __future__ import annotations
8
9from functools import cached_property
10from typing import Any
11
12import MySQLdb as Database
13from MySQLdb.constants import CLIENT, FIELD_TYPE
14from MySQLdb.converters import conversions
15
16from plain.exceptions import ImproperlyConfigured
17from plain.models.backends import utils as backend_utils
18from plain.models.backends.base.base import BaseDatabaseWrapper
19from plain.models.db import IntegrityError
20from plain.utils.regex_helper import _lazy_re_compile
21
22# With mysqlclient stubs, we can now type the connection
23try:
24 from MySQLdb.connections import Connection as MySQLConnection
25except ImportError:
26 MySQLConnection = Any # type: ignore[misc, assignment]
27
28from .client import DatabaseClient
29from .creation import DatabaseCreation
30from .features import DatabaseFeatures
31from .introspection import DatabaseIntrospection
32from .operations import DatabaseOperations
33from .schema import DatabaseSchemaEditor
34from .validation import DatabaseValidation
35
# MySQLdb hands TIME columns back as timedelta objects (which matches how the
# column actually behaves: values are signed and can include days), whereas
# Plain expects time values. Override only that converter and keep the
# driver's defaults for every other field type.
plain_conversions = {
    **conversions,
    FIELD_TYPE.TIME: backend_utils.typecast_time,
}

# Matches only the numeric portion of a server version string, so versions
# like 5.0.24 and 5.0.24a are treated as the same.
server_version_re = _lazy_re_compile(r"(\d{1,2})\.(\d{1,2})\.(\d{1,2})")
47
48
class CursorWrapper:
    """
    Proxy around a MySQLdb cursor that re-raises selected driver errors as
    Plain's IntegrityError.

    Composition is used rather than subclassing so the proxy is not tied to
    whichever concrete cursor class Connection.cursor() returns. Any
    attribute not defined here is looked up on the wrapped cursor.
    """

    # Error codes the driver raises as OperationalError but which are
    # re-raised as IntegrityError by execute() and executemany().
    codes_for_integrityerror = (
        1048,  # Column cannot be null
        1690,  # BIGINT UNSIGNED value is out of range
        3819,  # CHECK constraint is violated
        4025,  # CHECK constraint failed
    )

    def __init__(self, cursor: Any) -> None:
        """Remember the underlying driver cursor to delegate to."""
        self.cursor = cursor

    def execute(self, query: str, args: Any = None) -> int:
        """
        Run ``query`` on the wrapped cursor and return its result.

        An OperationalError whose first argument is listed in
        ``codes_for_integrityerror`` is re-raised as IntegrityError with the
        same arguments; any other OperationalError propagates unchanged.
        """
        try:
            # Leaving args as None means no string interpolation is done.
            outcome = self.cursor.execute(query, args)
        except Database.OperationalError as exc:
            # These codes look misclassified by the driver; Plain prefers
            # the more logical exception type for them.
            if exc.args[0] not in self.codes_for_integrityerror:
                raise
            raise IntegrityError(*exc.args)
        return outcome

    def executemany(self, query: str, args: Any) -> int:
        """
        Run ``query`` once per parameter set in ``args`` and return the
        wrapped cursor's result.

        Applies the same OperationalError -> IntegrityError remapping as
        ``execute()``.
        """
        try:
            outcome = self.cursor.executemany(query, args)
        except Database.OperationalError as exc:
            # These codes look misclassified by the driver; Plain prefers
            # the more logical exception type for them.
            if exc.args[0] not in self.codes_for_integrityerror:
                raise
            raise IntegrityError(*exc.args)
        return outcome

    def __getattr__(self, attr: str) -> Any:
        """Fall back to the wrapped cursor for attributes not defined here."""
        wrapped = self.cursor
        return getattr(wrapped, attr)

    def __iter__(self) -> Any:
        """Iterate over the wrapped cursor."""
        wrapped = self.cursor
        return iter(wrapped)
94
95
class MySQLDatabaseWrapper(BaseDatabaseWrapper):
    """
    Connection wrapper for MySQL and MariaDB, backed by the MySQLdb
    (mysqlclient) driver.

    Server facts (version string, SQL mode, MariaDB detection, ...) are
    fetched through ``mysql_server_data`` and cached on the instance via
    ``cached_property``.
    """

    # Type checker hints: narrow base class attribute types to backend-specific classes
    ops: DatabaseOperations
    features: DatabaseFeatures
    introspection: DatabaseIntrospection
    creation: DatabaseCreation

    vendor = "mysql"
    # This dictionary maps Field objects to their associated MySQL column
    # types, as strings. Column-type strings can contain format strings; they'll
    # be interpolated against the values of Field.__dict__ before being output.
    # If a column type is set to None, it won't be included in the output.
    data_types = {
        "PrimaryKeyField": "bigint AUTO_INCREMENT",
        "BinaryField": "longblob",
        "BooleanField": "bool",
        "CharField": "varchar(%(max_length)s)",
        "DateField": "date",
        "DateTimeField": "datetime(6)",
        "DecimalField": "numeric(%(max_digits)s, %(decimal_places)s)",
        "DurationField": "bigint",
        "FloatField": "double precision",
        "IntegerField": "integer",
        "BigIntegerField": "bigint",
        "GenericIPAddressField": "char(39)",
        "JSONField": "json",
        "PositiveBigIntegerField": "bigint UNSIGNED",
        "PositiveIntegerField": "integer UNSIGNED",
        "PositiveSmallIntegerField": "smallint UNSIGNED",
        "SmallIntegerField": "smallint",
        "TextField": "longtext",
        "TimeField": "time(6)",
        "UUIDField": "char(32)",
    }

    # For these data types:
    # - MySQL < 8.0.13 doesn't accept default values and implicitly treats them
    #   as nullable
    # - all versions of MySQL and MariaDB don't support full width database
    #   indexes
    _limited_data_types = (
        "tinyblob",
        "blob",
        "mediumblob",
        "longblob",
        "tinytext",
        "text",
        "mediumtext",
        "longtext",
        "json",
    )

    # SQL fragments for each lookup type; the case-sensitive variants use
    # LIKE BINARY while the "i"-prefixed ones use plain LIKE.
    operators = {
        "exact": "= %s",
        "iexact": "LIKE %s",
        "contains": "LIKE BINARY %s",
        "icontains": "LIKE %s",
        "gt": "> %s",
        "gte": ">= %s",
        "lt": "< %s",
        "lte": "<= %s",
        "startswith": "LIKE BINARY %s",
        "endswith": "LIKE BINARY %s",
        "istartswith": "LIKE %s",
        "iendswith": "LIKE %s",
    }

    # The patterns below are used to generate SQL pattern lookup clauses when
    # the right-hand side of the lookup isn't a raw string (it might be an expression
    # or the result of a bilateral transformation).
    # In those cases, special characters for LIKE operators (e.g. \, %, _) should be
    # escaped on database side.
    #
    # Note: we use str.format() here for readability as '%' is used as a wildcard for
    # the LIKE operator.
    pattern_esc = r"REPLACE(REPLACE(REPLACE({}, '\\', '\\\\'), '%%', '\%%'), '_', '\_')"
    pattern_ops = {
        "contains": "LIKE BINARY CONCAT('%%', {}, '%%')",
        "icontains": "LIKE CONCAT('%%', {}, '%%')",
        "startswith": "LIKE BINARY CONCAT({}, '%%')",
        "istartswith": "LIKE CONCAT({}, '%%')",
        "endswith": "LIKE BINARY CONCAT('%%', {})",
        "iendswith": "LIKE CONCAT('%%', {})",
    }

    # Accepted (lower-case) values for the "isolation_level" option.
    isolation_levels = {
        "read uncommitted",
        "read committed",
        "repeatable read",
        "serializable",
    }

    Database = Database
    SchemaEditorClass = DatabaseSchemaEditor
    # Classes instantiated in __init__().
    client_class = DatabaseClient
    creation_class = DatabaseCreation
    features_class = DatabaseFeatures
    introspection_class = DatabaseIntrospection
    ops_class = DatabaseOperations
    validation_class = DatabaseValidation

    def get_database_version(self) -> tuple[int, int, int]:
        """Return the server version as parsed by ``mysql_version``."""
        return self.mysql_version

    def get_connection_params(self) -> dict[str, Any]:
        """
        Build the keyword arguments used to open a driver connection.

        Starts from Plain's converters and a ``utf8`` charset, adds the
        USER / NAME / PASSWORD / HOST / PORT entries of ``settings_dict``
        that are set, requests ``CLIENT.FOUND_ROWS``, and finally overlays
        the remaining ``OPTIONS`` entries (which may therefore override any
        of the earlier keys).

        Side effect: the ``isolation_level`` option (default
        ``"read committed"``) is removed from the options, lower-cased,
        validated and stored on ``self.isolation_level``. A falsy value is
        stored as-is without validation.

        Raises:
            ImproperlyConfigured: if a non-empty isolation level is not one
                of ``isolation_levels``.
        """
        kwargs: dict[str, Any] = {
            "conv": plain_conversions,
            "charset": "utf8",
        }
        settings_dict = self.settings_dict
        if settings_dict["USER"]:
            kwargs["user"] = settings_dict["USER"]
        if settings_dict["NAME"]:
            kwargs["database"] = settings_dict["NAME"]
        if settings_dict["PASSWORD"]:
            kwargs["password"] = settings_dict["PASSWORD"]
        host = settings_dict["HOST"]
        # Check truthiness before calling startswith() so an unset (None)
        # host is skipped like the other optional settings instead of
        # raising AttributeError.
        if host and host.startswith("/"):
            # A leading slash denotes a filesystem path to a Unix socket.
            kwargs["unix_socket"] = host
        elif host:
            kwargs["host"] = host
        if settings_dict["PORT"]:
            kwargs["port"] = int(settings_dict["PORT"])
        # We need the number of potentially affected rows after an
        # "UPDATE", not the number of changed rows.
        kwargs["client_flag"] = CLIENT.FOUND_ROWS
        # Validate the transaction isolation level, if specified.
        # Work on a copy so the caller's OPTIONS dict is not mutated by pop().
        options = settings_dict["OPTIONS"].copy()
        isolation_level = options.pop("isolation_level", "read committed")
        if isolation_level:
            isolation_level = isolation_level.lower()
            if isolation_level not in self.isolation_levels:
                valid_levels = ", ".join(
                    f"'{s}'" for s in sorted(self.isolation_levels)
                )
                raise ImproperlyConfigured(
                    f"Invalid transaction isolation level '{isolation_level}' specified.\n"
                    f"Use one of {valid_levels}, or None."
                )
        self.isolation_level = isolation_level
        kwargs.update(options)
        return kwargs

    def get_new_connection(self, conn_params: dict[str, Any]) -> MySQLConnection:
        """Open and return a driver connection using ``conn_params``."""
        connection = Database.connect(**conn_params)
        # bytes encoder in mysqlclient doesn't work and was added only to
        # prevent KeyErrors in Plain < 2.0. We can remove this workaround when
        # mysqlclient 2.1 becomes the minimal mysqlclient supported by Plain.
        # See https://github.com/PyMySQL/mysqlclient/issues/489
        if connection.encoders.get(bytes) is bytes:
            connection.encoders.pop(bytes)
        return connection

    def init_connection_state(self) -> None:
        """
        Apply session-level settings to the current connection.

        After the base-class initialisation, turns off SQL_AUTO_IS_NULL when
        the server has it enabled and sets the transaction isolation level
        stored in ``self.isolation_level`` (assigned by
        ``get_connection_params()``). All statements are sent in a single
        ``execute()`` call, joined with ``"; "``.
        """
        super().init_connection_state()
        assignments = []
        if self.features.is_sql_auto_is_null_enabled:
            # SQL_AUTO_IS_NULL controls whether an AUTO_INCREMENT column on
            # a recently inserted row will return when the field is tested
            # for NULL. Disabling this brings this aspect of MySQL in line
            # with SQL standards.
            assignments.append("SET SQL_AUTO_IS_NULL = 0")

        if self.isolation_level:
            assignments.append(
                f"SET SESSION TRANSACTION ISOLATION LEVEL {self.isolation_level.upper()}"
            )

        if assignments:
            with self.cursor() as cursor:
                cursor.execute("; ".join(assignments))

    def create_cursor(self, name: str | None = None) -> CursorWrapper:
        """
        Return a new driver cursor wrapped in ``CursorWrapper``.

        ``name`` is accepted but not used by this backend.
        """
        cursor = self.connection.cursor()
        return CursorWrapper(cursor)

    def _rollback(self) -> None:
        """
        Roll back via the base class, ignoring ``NotSupportedError``.

        NOTE(review): presumably the driver raises this when the storage
        engine is non-transactional -- confirm.
        """
        try:
            BaseDatabaseWrapper._rollback(self)
        except Database.NotSupportedError:
            # Deliberate best-effort: a rollback that the server cannot
            # perform is treated as a no-op.
            pass

    def _set_autocommit(self, autocommit: bool) -> None:
        """Switch the driver connection's autocommit mode on or off."""
        with self.wrap_database_errors:
            self.connection.autocommit(autocommit)

    def check_constraints(self, table_names: list[str] | None = None) -> None:
        """
        Check ``table_names`` for rows with invalid foreign key references.

        When ``table_names`` is None, every table reported by introspection
        is checked. Tables without a primary key column are skipped.

        Raises:
            IntegrityError: for the first row found whose foreign key value
                has no matching row in the referenced table.
        """
        with self.cursor() as cursor:
            if table_names is None:
                table_names = self.introspection.table_names(cursor)
            for table_name in table_names:
                primary_key_column_name = self.introspection.get_primary_key_column(
                    cursor, table_name
                )
                if not primary_key_column_name:
                    continue
                relations = self.introspection.get_relations(cursor, table_name)
                for column_name, (
                    referenced_column_name,
                    referenced_table_name,
                ) in relations.items():
                    # LEFT JOIN the referenced table and keep rows whose
                    # non-NULL foreign key found no partner.
                    cursor.execute(
                        f"""
                        SELECT REFERRING.`{primary_key_column_name}`, REFERRING.`{column_name}` FROM `{table_name}` as REFERRING
                        LEFT JOIN `{referenced_table_name}` as REFERRED
                        ON (REFERRING.`{column_name}` = REFERRED.`{referenced_column_name}`)
                        WHERE REFERRING.`{column_name}` IS NOT NULL AND REFERRED.`{referenced_column_name}` IS NULL
                        """
                    )
                    # The loop body raises immediately, so only the first
                    # offending row is ever reported.
                    for bad_row in cursor.fetchall():
                        raise IntegrityError(
                            f"The row in table '{table_name}' with primary key '{bad_row[0]}' has an "
                            f"invalid foreign key: {table_name}.{column_name} contains a value '{bad_row[1]}' that "
                            f"does not have a corresponding value in {referenced_table_name}.{referenced_column_name}."
                        )

    def is_usable(self) -> bool:
        """Return True if pinging the server succeeds, False on a driver error."""
        try:
            self.connection.ping()
        except Database.Error:
            return False
        else:
            return True

    @cached_property
    def display_name(self) -> str:
        """Human-readable server name: ``"MariaDB"`` or ``"MySQL"``."""
        return "MariaDB" if self.mysql_is_mariadb else "MySQL"

    @cached_property
    def data_type_check_constraints(self) -> dict[str, str]:
        """
        Map field type names to the CHECK constraint template for them.

        Empty when the server does not support column check constraints.
        """
        if self.features.supports_column_check_constraints:
            check_constraints = {
                "PositiveBigIntegerField": "`%(column)s` >= 0",
                "PositiveIntegerField": "`%(column)s` >= 0",
                "PositiveSmallIntegerField": "`%(column)s` >= 0",
            }
            if self.mysql_is_mariadb and self.mysql_version < (10, 4, 3):
                # MariaDB < 10.4.3 doesn't automatically use the JSON_VALID as
                # a check constraint.
                check_constraints["JSONField"] = "JSON_VALID(`%(column)s`)"
            return check_constraints
        return {}

    @cached_property
    def mysql_server_data(self) -> dict[str, Any]:
        """
        Query the server once for its version and a few global settings.

        Returns a dict with the keys ``version``, ``sql_mode``,
        ``default_storage_engine``, ``sql_auto_is_null``,
        ``lower_case_table_names`` and ``has_zoneinfo_database``; the last
        three are coerced to bool.
        """
        with self.temporary_connection() as cursor:
            # Select some server variables and test if the time zone
            # definitions are installed. CONVERT_TZ returns NULL if 'UTC'
            # timezone isn't loaded into the mysql.time_zone table.
            cursor.execute(
                """
                SELECT VERSION(),
                       @@sql_mode,
                       @@default_storage_engine,
                       @@sql_auto_is_null,
                       @@lower_case_table_names,
                       CONVERT_TZ('2001-01-01 01:00:00', 'UTC', 'UTC') IS NOT NULL
                """
            )
            row = cursor.fetchone()
        return {
            "version": row[0],
            "sql_mode": row[1],
            "default_storage_engine": row[2],
            "sql_auto_is_null": bool(row[3]),
            "lower_case_table_names": bool(row[4]),
            "has_zoneinfo_database": bool(row[5]),
        }

    @cached_property
    def mysql_server_info(self) -> str:
        """The raw version string reported by the server's ``VERSION()``."""
        return self.mysql_server_data["version"]

    @cached_property
    def mysql_version(self) -> tuple[int, int, int]:
        """
        Parse the leading ``major.minor.patch`` numbers of the version string.

        Raises:
            Exception: if the version string does not start with three
                dot-separated numbers.
        """
        match = server_version_re.match(self.mysql_server_info)
        if not match:
            raise Exception(
                f"Unable to determine MySQL version from version string {self.mysql_server_info!r}"
            )
        return tuple(int(x) for x in match.groups())

    @cached_property
    def mysql_is_mariadb(self) -> bool:
        """True when the version string mentions "mariadb" (case-insensitive)."""
        return "mariadb" in self.mysql_server_info.lower()

    @cached_property
    def sql_mode(self) -> set[str]:
        """The server's ``@@sql_mode`` split on commas; empty set if unset."""
        sql_mode = self.mysql_server_data["sql_mode"]
        return set(sql_mode.split(",") if sql_mode else ())