1"""
2MySQL database backend for Plain.
3
4Requires mysqlclient: https://pypi.org/project/mysqlclient/
5"""
6
7from __future__ import annotations
8
9from functools import cached_property
10from typing import TYPE_CHECKING, Any, cast
11
12import MySQLdb as Database
13from MySQLdb.constants import CLIENT, FIELD_TYPE
14from MySQLdb.converters import conversions
15
16from plain.exceptions import ImproperlyConfigured
17from plain.models.backends import utils as backend_utils
18from plain.models.backends.base.base import BaseDatabaseWrapper
19from plain.models.db import IntegrityError
20from plain.utils.regex_helper import _lazy_re_compile
21
22# Type checkers always see the proper type; runtime falls back to Any if needed
23if TYPE_CHECKING:
24 from MySQLdb.connections import Connection as MySQLConnection
25else:
26 try:
27 from MySQLdb.connections import Connection as MySQLConnection
28 except ImportError:
29 MySQLConnection = Any # type: ignore[misc]
30
31from .client import DatabaseClient
32from .creation import DatabaseCreation
33from .features import DatabaseFeatures
34from .introspection import DatabaseIntrospection
35from .operations import DatabaseOperations
36from .schema import DatabaseSchemaEditor
37from .validation import DatabaseValidation
38
# MySQLdb returns TIME columns as timedelta -- they are more like timedelta in
# terms of actual behavior as they are signed and include days -- and Plain
# expects time. Start from mysqlclient's default converter table and replace
# only the TIME converter with Plain's typecast_time.
plain_conversions = {
    **conversions,
    FIELD_TYPE.TIME: backend_utils.typecast_time,
}
46
# This should match the numerical portion of the version numbers (we can treat
# versions like 5.0.24 and 5.0.24a as the same). Components are matched with
# \d+ so that a multi-digit component (e.g. a patch level >= 100) is parsed in
# full instead of being silently truncated to its first two digits.
server_version_re = _lazy_re_compile(r"(\d+)\.(\d+)\.(\d+)")
50
51
class CursorWrapper:
    """
    A thin wrapper around MySQLdb's normal cursor class that catches particular
    exception instances and reraises them with the correct types.

    Implemented as a wrapper, rather than a subclass, so that it isn't stuck
    to the particular underlying representation returned by Connection.cursor().
    Any attribute not defined here, and iteration, are delegated to the
    wrapped cursor.
    """

    # MySQL error codes that MySQLdb raises as OperationalError but which are
    # really data-integrity violations; they are re-raised as IntegrityError.
    codes_for_integrityerror = (
        1048,  # Column cannot be null
        1690,  # BIGINT UNSIGNED value is out of range
        3819,  # CHECK constraint is violated
        4025,  # CHECK constraint failed
    )

    def __init__(self, cursor: Any) -> None:
        self.cursor = cursor

    def execute(self, query: str, args: Any = None) -> int:
        """
        Execute ``query`` once and return the wrapped cursor's result.

        ``args`` of None means no string interpolation is performed.

        Raises:
            IntegrityError: if MySQLdb raises an OperationalError whose code is
                listed in ``codes_for_integrityerror``.
        """
        try:
            return self.cursor.execute(query, args)
        except Database.OperationalError as e:
            self._reraise_as_integrity_error(e)
            raise

    def executemany(self, query: str, args: Any) -> int:
        """
        Execute ``query`` for each parameter set in ``args`` and return the
        wrapped cursor's result.

        Raises:
            IntegrityError: if MySQLdb raises an OperationalError whose code is
                listed in ``codes_for_integrityerror``.
        """
        try:
            return self.cursor.executemany(query, args)
        except Database.OperationalError as e:
            self._reraise_as_integrity_error(e)
            raise

    def _reraise_as_integrity_error(self, error: Exception) -> None:
        """
        Raise ``error`` as IntegrityError if its code is a misclassified one.

        Map some error codes to IntegrityError, since they seem to be
        misclassified and Plain would prefer the more logical place. The code
        is read from ``args[0]``; an error without args is left alone so the
        caller re-raises the original exception instead of hitting IndexError.
        """
        if error.args and error.args[0] in self.codes_for_integrityerror:
            raise IntegrityError(*error.args) from error

    def __getattr__(self, attr: str) -> Any:
        return getattr(self.cursor, attr)

    def __iter__(self) -> Any:
        return iter(self.cursor)
97
98
class MySQLDatabaseWrapper(BaseDatabaseWrapper):
    """Database wrapper for MySQL and MariaDB servers using mysqlclient (MySQLdb)."""

    # Type checker hints: narrow base class attribute types to backend-specific classes
    ops: DatabaseOperations
    features: DatabaseFeatures
    introspection: DatabaseIntrospection
    creation: DatabaseCreation

    vendor = "mysql"
    # This dictionary maps Field objects to their associated MySQL column
    # types, as strings. Column-type strings can contain format strings; they'll
    # be interpolated against the values of Field.__dict__ before being output.
    # If a column type is set to None, it won't be included in the output.
    data_types = {
        "PrimaryKeyField": "bigint AUTO_INCREMENT",
        "BinaryField": "longblob",
        "BooleanField": "bool",
        "CharField": "varchar(%(max_length)s)",
        "DateField": "date",
        "DateTimeField": "datetime(6)",
        "DecimalField": "numeric(%(max_digits)s, %(decimal_places)s)",
        "DurationField": "bigint",
        "FloatField": "double precision",
        "IntegerField": "integer",
        "BigIntegerField": "bigint",
        "GenericIPAddressField": "char(39)",
        "JSONField": "json",
        "PositiveBigIntegerField": "bigint UNSIGNED",
        "PositiveIntegerField": "integer UNSIGNED",
        "PositiveSmallIntegerField": "smallint UNSIGNED",
        "SmallIntegerField": "smallint",
        "TextField": "longtext",
        "TimeField": "time(6)",
        "UUIDField": "char(32)",
    }

    # For these data types:
    # - MySQL < 8.0.13 doesn't accept default values and implicitly treats them
    #   as nullable
    # - all versions of MySQL and MariaDB don't support full width database
    #   indexes
    _limited_data_types = (
        "tinyblob",
        "blob",
        "mediumblob",
        "longblob",
        "tinytext",
        "text",
        "mediumtext",
        "longtext",
        "json",
    )

    # SQL fragments for lookups whose right-hand side is a plain parameter.
    # The case-sensitive pattern lookups use LIKE BINARY; their "i" variants
    # use plain LIKE.
    operators = {
        "exact": "= %s",
        "iexact": "LIKE %s",
        "contains": "LIKE BINARY %s",
        "icontains": "LIKE %s",
        "gt": "> %s",
        "gte": ">= %s",
        "lt": "< %s",
        "lte": "<= %s",
        "startswith": "LIKE BINARY %s",
        "endswith": "LIKE BINARY %s",
        "istartswith": "LIKE %s",
        "iendswith": "LIKE %s",
    }

    # The patterns below are used to generate SQL pattern lookup clauses when
    # the right-hand side of the lookup isn't a raw string (it might be an expression
    # or the result of a bilateral transformation).
    # In those cases, special characters for LIKE operators (e.g. \, %, _) should be
    # escaped on database side.
    #
    # Note: we use str.format() here for readability as '%' is used as a wildcard for
    # the LIKE operator.
    pattern_esc = r"REPLACE(REPLACE(REPLACE({}, '\\', '\\\\'), '%%', '\%%'), '_', '\_')"
    pattern_ops = {
        "contains": "LIKE BINARY CONCAT('%%', {}, '%%')",
        "icontains": "LIKE CONCAT('%%', {}, '%%')",
        "startswith": "LIKE BINARY CONCAT({}, '%%')",
        "istartswith": "LIKE CONCAT({}, '%%')",
        "endswith": "LIKE BINARY CONCAT('%%', {})",
        "iendswith": "LIKE CONCAT('%%', {})",
    }

    # Accepted (lowercase) values for OPTIONS["isolation_level"].
    isolation_levels = {
        "read uncommitted",
        "read committed",
        "repeatable read",
        "serializable",
    }

    Database = Database
    SchemaEditorClass = DatabaseSchemaEditor
    # Classes instantiated in __init__().
    client_class = DatabaseClient
    creation_class = DatabaseCreation
    features_class = DatabaseFeatures
    introspection_class = DatabaseIntrospection
    ops_class = DatabaseOperations
    validation_class = DatabaseValidation

    def get_database_version(self) -> tuple[int, int, int]:
        """Return the server version as a ``(major, minor, patch)`` tuple."""
        return self.mysql_version

    def get_connection_params(self) -> dict[str, Any]:
        """
        Build the keyword arguments passed to ``MySQLdb.connect()``.

        Values come from ``self.settings_dict``. Remaining ``OPTIONS`` entries
        are merged in last, so they override the defaults set here. As a side
        effect, the validated isolation level (popped from a copy of
        ``OPTIONS``, defaulting to "read committed") is stored on
        ``self.isolation_level`` for ``init_connection_state()``.

        Raises:
            ImproperlyConfigured: if a truthy ``isolation_level`` option is not
                one of ``isolation_levels`` (compared case-insensitively).
        """
        settings_dict = self.settings_dict
        kwargs: dict[str, Any] = {
            "conv": plain_conversions,
            "charset": "utf8",
        }
        if settings_dict["USER"]:
            kwargs["user"] = settings_dict["USER"]
        if settings_dict["NAME"]:
            kwargs["database"] = settings_dict["NAME"]
        if settings_dict["PASSWORD"]:
            kwargs["password"] = settings_dict["PASSWORD"]
        # A HOST beginning with "/" is treated as a Unix socket path. A missing
        # (None) HOST is treated like an empty one rather than crashing on
        # .startswith().
        host = settings_dict["HOST"] or ""
        if host.startswith("/"):
            kwargs["unix_socket"] = host
        elif host:
            kwargs["host"] = host
        if settings_dict["PORT"]:
            kwargs["port"] = int(settings_dict["PORT"])
        # We need the number of potentially affected rows after an
        # "UPDATE", not the number of changed rows.
        kwargs["client_flag"] = CLIENT.FOUND_ROWS
        # Validate the transaction isolation level, if specified. OPTIONS is
        # copied so that popping the key doesn't mutate the settings.
        options = settings_dict["OPTIONS"].copy()
        isolation_level = options.pop("isolation_level", "read committed")
        if isolation_level:
            isolation_level = isolation_level.lower()
            if isolation_level not in self.isolation_levels:
                raise ImproperlyConfigured(
                    "Invalid transaction isolation level '{}' specified.\n"
                    "Use one of {}, or None.".format(
                        isolation_level,
                        ", ".join(f"'{s}'" for s in sorted(self.isolation_levels)),
                    )
                )
        self.isolation_level = isolation_level
        kwargs.update(options)
        return kwargs

    def get_new_connection(self, conn_params: dict[str, Any]) -> MySQLConnection:
        """Open a new MySQLdb connection with ``conn_params``."""
        connection = Database.connect(**conn_params)
        # Older mysqlclient releases register an identity encoder for bytes
        # (``encoders[bytes] is bytes``) that doesn't work; it was added only
        # to prevent KeyErrors in Django < 2.0. Drop it when present. This
        # workaround can be removed once mysqlclient 2.1 is the minimum
        # supported mysqlclient version.
        # See https://github.com/PyMySQL/mysqlclient/issues/489
        if connection.encoders.get(bytes) is bytes:
            connection.encoders.pop(bytes)
        return connection

    def init_connection_state(self) -> None:
        """
        Apply per-session settings to a freshly opened connection.

        Disables SQL_AUTO_IS_NULL when the server has it enabled and sets the
        session isolation level chosen in ``get_connection_params()``. All
        statements are sent in a single ``execute()`` call.
        """
        super().init_connection_state()
        assignments = []
        if self.features.is_sql_auto_is_null_enabled:
            # SQL_AUTO_IS_NULL controls whether an AUTO_INCREMENT column on
            # a recently inserted row will return when the field is tested
            # for NULL. Disabling this brings this aspect of MySQL in line
            # with SQL standards.
            assignments.append("SET SQL_AUTO_IS_NULL = 0")

        if self.isolation_level:
            assignments.append(
                f"SET SESSION TRANSACTION ISOLATION LEVEL {self.isolation_level.upper()}"
            )

        if assignments:
            with self.cursor() as cursor:
                cursor.execute("; ".join(assignments))

    def create_cursor(self, name: str | None = None) -> CursorWrapper:
        """Return a new cursor wrapped in CursorWrapper (``name`` is unused)."""
        cursor = self.connection.cursor()
        return CursorWrapper(cursor)

    def _rollback(self) -> None:
        """Roll back the current transaction, ignoring NotSupportedError."""
        try:
            super()._rollback()
        except Database.NotSupportedError:
            # NOTE(review): presumably raised when rollback isn't possible on
            # this connection; treated as a no-op. Confirm the base class
            # doesn't re-wrap it into a different exception type first.
            pass

    def _set_autocommit(self, autocommit: bool) -> None:
        """Toggle autocommit on the underlying connection."""
        with self.wrap_database_errors:
            self.connection.autocommit(autocommit)

    def check_constraints(self, table_names: list[str] | None = None) -> None:
        """
        Check ``table_names`` for rows with invalid foreign key references.

        When ``table_names`` is None, every table reported by introspection is
        checked. Tables without a primary key column are skipped.

        Raises:
            IntegrityError: for the first row found whose foreign key value has
                no matching row in the referenced table.
        """
        with self.cursor() as cursor:
            if table_names is None:
                table_names = self.introspection.table_names(cursor)
            for table_name in table_names:
                primary_key_column_name = self.introspection.get_primary_key_column(
                    cursor, table_name
                )
                if not primary_key_column_name:
                    continue
                relations = self.introspection.get_relations(cursor, table_name)
                for column_name, (
                    referenced_column_name,
                    referenced_table_name,
                ) in relations.items():
                    # Only the first offending row is reported, so ask the
                    # server for a single row instead of fetching all of them.
                    cursor.execute(
                        f"""
                        SELECT REFERRING.`{primary_key_column_name}`, REFERRING.`{column_name}` FROM `{table_name}` as REFERRING
                        LEFT JOIN `{referenced_table_name}` as REFERRED
                        ON (REFERRING.`{column_name}` = REFERRED.`{referenced_column_name}`)
                        WHERE REFERRING.`{column_name}` IS NOT NULL AND REFERRED.`{referenced_column_name}` IS NULL
                        LIMIT 1
                        """
                    )
                    bad_row = cursor.fetchone()
                    if bad_row is not None:
                        raise IntegrityError(
                            f"The row in table '{table_name}' with primary key '{bad_row[0]}' has an "
                            f"invalid foreign key: {table_name}.{column_name} contains a value '{bad_row[1]}' that "
                            f"does not have a corresponding value in {referenced_table_name}.{referenced_column_name}."
                        )

    def is_usable(self) -> bool:
        """Return True if the connection answers a ping, False on any Database.Error."""
        try:
            self.connection.ping()
        except Database.Error:
            return False
        else:
            return True

    @cached_property
    def display_name(self) -> str:
        """Human-readable server name: "MariaDB" or "MySQL"."""
        return "MariaDB" if self.mysql_is_mariadb else "MySQL"

    @cached_property
    def data_type_check_constraints(self) -> dict[str, str]:
        """
        Column CHECK constraints keyed by field type, or an empty dict when the
        server doesn't support column check constraints.
        """
        if self.features.supports_column_check_constraints:
            check_constraints = {
                "PositiveBigIntegerField": "`%(column)s` >= 0",
                "PositiveIntegerField": "`%(column)s` >= 0",
                "PositiveSmallIntegerField": "`%(column)s` >= 0",
            }
            if self.mysql_is_mariadb and self.mysql_version < (10, 4, 3):
                # MariaDB < 10.4.3 doesn't automatically use the JSON_VALID as
                # a check constraint.
                check_constraints["JSONField"] = "JSON_VALID(`%(column)s`)"
            return check_constraints
        return {}

    @cached_property
    def mysql_server_data(self) -> dict[str, Any]:
        """
        Server metadata, queried once over a temporary connection and cached.

        Keys: "version", "sql_mode", "default_storage_engine",
        "sql_auto_is_null", "lower_case_table_names" and
        "has_zoneinfo_database" (the last three coerced to bool).
        """
        with self.temporary_connection() as cursor:
            # Select some server variables and test if the time zone
            # definitions are installed. CONVERT_TZ returns NULL if 'UTC'
            # timezone isn't loaded into the mysql.time_zone table.
            cursor.execute(
                """
                SELECT VERSION(),
                       @@sql_mode,
                       @@default_storage_engine,
                       @@sql_auto_is_null,
                       @@lower_case_table_names,
                       CONVERT_TZ('2001-01-01 01:00:00', 'UTC', 'UTC') IS NOT NULL
            """
            )
            row = cursor.fetchone()
        return {
            "version": row[0],
            "sql_mode": row[1],
            "default_storage_engine": row[2],
            "sql_auto_is_null": bool(row[3]),
            "lower_case_table_names": bool(row[4]),
            "has_zoneinfo_database": bool(row[5]),
        }

    @cached_property
    def mysql_server_info(self) -> str:
        """Raw server version string as returned by ``SELECT VERSION()``."""
        return self.mysql_server_data["version"]

    @cached_property
    def mysql_version(self) -> tuple[int, int, int]:
        """
        Numeric ``(major, minor, patch)`` parsed from the server version string.

        Raises:
            Exception: if the version string doesn't start with a
                ``major.minor.patch`` number.
        """
        match = server_version_re.match(self.mysql_server_info)
        if not match:
            raise Exception(
                f"Unable to determine MySQL version from version string {self.mysql_server_info!r}"
            )
        return cast(tuple[int, int, int], tuple(int(x) for x in match.groups()))

    @cached_property
    def mysql_is_mariadb(self) -> bool:
        """True if the server version string mentions MariaDB (case-insensitive)."""
        return "mariadb" in self.mysql_server_info.lower()

    @cached_property
    def sql_mode(self) -> set[str]:
        """The server's @@sql_mode split on commas; empty set when unset."""
        sql_mode = self.mysql_server_data["sql_mode"]
        return set(sql_mode.split(",") if sql_mode else ())