1"""
2MySQL database backend for Plain.
3
4Requires mysqlclient: https://pypi.org/project/mysqlclient/
5"""
6
7from __future__ import annotations
8
9from functools import cached_property
10from typing import TYPE_CHECKING, Any, cast
11
12import MySQLdb as Database
13from MySQLdb.constants import CLIENT, FIELD_TYPE
14from MySQLdb.converters import conversions
15
16from plain.exceptions import ImproperlyConfigured
17from plain.models.backends import utils as backend_utils
18from plain.models.backends.base.base import BaseDatabaseWrapper
19from plain.models.db import IntegrityError
20from plain.utils.regex_helper import _lazy_re_compile
21
22# Type checkers always see the proper type; runtime falls back to Any if needed
23if TYPE_CHECKING:
24 from MySQLdb.connections import Connection as MySQLConnection
25else:
26 try:
27 from MySQLdb.connections import Connection as MySQLConnection
28 except ImportError:
29 MySQLConnection = Any
30
31from .client import DatabaseClient
32from .creation import DatabaseCreation
33from .features import DatabaseFeatures
34from .introspection import DatabaseIntrospection
35from .operations import DatabaseOperations
36from .schema import DatabaseSchemaEditor
37from .validation import DatabaseValidation
38
# MySQLdb returns TIME columns as timedelta -- they are more like timedelta in
# terms of actual behavior, as they are signed and include days -- but Plain
# expects datetime.time, so the TIME converter is overridden here while every
# other driver conversion is kept as-is.
plain_conversions = {
    **conversions,
    FIELD_TYPE.TIME: backend_utils.typecast_time,
}

# This should match the numerical "major.minor.patch" portion of the version
# numbers (we can treat versions like 5.0.24 and 5.0.24a as the same). It is
# applied with .match(), so only a version at the start of the string counts.
server_version_re = _lazy_re_compile(r"(\d{1,2})\.(\d{1,2})\.(\d{1,2})")
50
51
class CursorWrapper:
    """
    A thin wrapper around MySQLdb's normal cursor class that catches particular
    exception instances and reraises them with the correct types.

    Implemented as a wrapper, rather than a subclass, so that it isn't stuck
    to the particular underlying representation returned by Connection.cursor().
    Any attribute not defined on the wrapper is delegated to the wrapped cursor.
    """

    # MySQL error codes that the driver reports as OperationalError but that
    # are really integrity violations; execute()/executemany() re-raise them
    # as IntegrityError.
    codes_for_integrityerror = (
        1048,  # Column cannot be null
        1690,  # BIGINT UNSIGNED value is out of range
        3819,  # CHECK constraint is violated
        4025,  # CHECK constraint failed
    )

    def __init__(self, cursor: Any) -> None:
        self.cursor = cursor

    def _raise_if_integrity_error(self, error: Exception) -> None:
        """
        Re-raise ``error`` as IntegrityError when its code is misclassified.

        Returns normally when the error's code (its first arg) is not listed in
        ``codes_for_integrityerror``, so the caller can re-raise the original.
        """
        # Map some error codes to IntegrityError, since they seem to be
        # misclassified and Plain would prefer the more logical place.
        # Guard against an error with no args rather than raising IndexError.
        if error.args and error.args[0] in self.codes_for_integrityerror:
            raise IntegrityError(*error.args) from error

    def execute(self, query: str, args: Any = None) -> int:
        """Execute ``query``; ``args`` of None means no string interpolation."""
        try:
            return self.cursor.execute(query, args)
        except Database.OperationalError as e:
            self._raise_if_integrity_error(e)
            raise

    def executemany(self, query: str, args: Any) -> int:
        """Execute ``query`` against each parameter set in ``args``."""
        try:
            return self.cursor.executemany(query, args)
        except Database.OperationalError as e:
            self._raise_if_integrity_error(e)
            raise

    def __getattr__(self, attr: str) -> Any:
        # Only invoked for names not found on the wrapper itself.
        return getattr(self.cursor, attr)

    def __iter__(self) -> Any:
        return iter(self.cursor)
97
98
class MySQLDatabaseWrapper(BaseDatabaseWrapper):
    """
    Database wrapper for MySQL and MariaDB servers, using mysqlclient (MySQLdb).

    Server facts (version string, sql_mode, MariaDB detection) are queried
    lazily and cached per wrapper instance via ``cached_property``.
    """

    # Type checker hints: narrow base class attribute types to backend-specific classes
    ops: DatabaseOperations
    features: DatabaseFeatures
    introspection: DatabaseIntrospection
    creation: DatabaseCreation

    vendor = "mysql"
    # This dictionary maps Field objects to their associated MySQL column
    # types, as strings. Column-type strings can contain format strings; they'll
    # be interpolated against the values of Field.__dict__ before being output.
    # If a column type is set to None, it won't be included in the output.
    data_types = {
        "PrimaryKeyField": "bigint AUTO_INCREMENT",
        "BinaryField": "longblob",
        "BooleanField": "bool",
        "CharField": "varchar(%(max_length)s)",
        "DateField": "date",
        "DateTimeField": "datetime(6)",
        "DecimalField": "numeric(%(max_digits)s, %(decimal_places)s)",
        "DurationField": "bigint",
        "FloatField": "double precision",
        "IntegerField": "integer",
        "BigIntegerField": "bigint",
        "GenericIPAddressField": "char(39)",
        "JSONField": "json",
        "PositiveBigIntegerField": "bigint UNSIGNED",
        "PositiveIntegerField": "integer UNSIGNED",
        "PositiveSmallIntegerField": "smallint UNSIGNED",
        "SmallIntegerField": "smallint",
        "TextField": "longtext",
        "TimeField": "time(6)",
        "UUIDField": "char(32)",
    }

    # For these data types:
    # - MySQL < 8.0.13 doesn't accept default values and implicitly treats them
    #   as nullable
    # - all versions of MySQL and MariaDB don't support full width database
    #   indexes
    _limited_data_types = (
        "tinyblob",
        "blob",
        "mediumblob",
        "longblob",
        "tinytext",
        "text",
        "mediumtext",
        "longtext",
        "json",
    )

    operators = {
        "exact": "= %s",
        "iexact": "LIKE %s",
        "contains": "LIKE BINARY %s",
        "icontains": "LIKE %s",
        "gt": "> %s",
        "gte": ">= %s",
        "lt": "< %s",
        "lte": "<= %s",
        "startswith": "LIKE BINARY %s",
        "endswith": "LIKE BINARY %s",
        "istartswith": "LIKE %s",
        "iendswith": "LIKE %s",
    }

    # The patterns below are used to generate SQL pattern lookup clauses when
    # the right-hand side of the lookup isn't a raw string (it might be an expression
    # or the result of a bilateral transformation).
    # In those cases, special characters for LIKE operators (e.g. \, %, _) should be
    # escaped on database side.
    #
    # Note: we use str.format() here for readability as '%' is used as a wildcard for
    # the LIKE operator.
    pattern_esc = r"REPLACE(REPLACE(REPLACE({}, '\\', '\\\\'), '%%', '\%%'), '_', '\_')"
    pattern_ops = {
        "contains": "LIKE BINARY CONCAT('%%', {}, '%%')",
        "icontains": "LIKE CONCAT('%%', {}, '%%')",
        "startswith": "LIKE BINARY CONCAT({}, '%%')",
        "istartswith": "LIKE CONCAT({}, '%%')",
        "endswith": "LIKE BINARY CONCAT('%%', {})",
        "iendswith": "LIKE CONCAT('%%', {})",
    }

    # Accepted values for OPTIONS["isolation_level"] (compared lowercased).
    isolation_levels = {
        "read uncommitted",
        "read committed",
        "repeatable read",
        "serializable",
    }

    Database = Database
    SchemaEditorClass = DatabaseSchemaEditor
    # Classes instantiated in __init__().
    client_class = DatabaseClient
    creation_class = DatabaseCreation
    features_class = DatabaseFeatures
    introspection_class = DatabaseIntrospection
    ops_class = DatabaseOperations
    validation_class = DatabaseValidation

    def get_database_version(self) -> tuple[int, int, int]:
        """Return the server version as a (major, minor, patch) tuple."""
        return self.mysql_version

    def get_connection_params(self) -> dict[str, Any]:
        """
        Build the keyword arguments for ``Database.connect()`` from settings.

        As a side effect, validates and stores ``self.isolation_level``
        (default "read committed"; a falsy value disables it). All other
        OPTIONS entries are merged in last, so they may override the defaults
        computed here.

        Raises:
            ImproperlyConfigured: if the isolation level is not recognized.
        """
        settings_dict = self.settings_dict
        kwargs: dict[str, Any] = {
            "conv": plain_conversions,
            "charset": "utf8",
        }
        if settings_dict["USER"]:
            kwargs["user"] = settings_dict["USER"]
        if settings_dict["NAME"]:
            kwargs["database"] = settings_dict["NAME"]
        if settings_dict["PASSWORD"]:
            kwargs["password"] = settings_dict["PASSWORD"]
        if settings_dict["HOST"].startswith("/"):
            # A HOST beginning with "/" is treated as a Unix socket path.
            kwargs["unix_socket"] = settings_dict["HOST"]
        elif settings_dict["HOST"]:
            kwargs["host"] = settings_dict["HOST"]
        if settings_dict["PORT"]:
            kwargs["port"] = int(settings_dict["PORT"])
        # We need the number of potentially affected rows after an
        # "UPDATE", not the number of changed rows.
        kwargs["client_flag"] = CLIENT.FOUND_ROWS
        # Copy OPTIONS so popping keys below never mutates the settings dict;
        # a missing or None OPTIONS is treated as empty.
        options = dict(settings_dict.get("OPTIONS") or {})
        # Validate the transaction isolation level, if specified.
        isolation_level = options.pop("isolation_level", "read committed")
        if isolation_level:
            isolation_level = isolation_level.lower()
            if isolation_level not in self.isolation_levels:
                raise ImproperlyConfigured(
                    "Invalid transaction isolation level '{}' specified.\n"
                    "Use one of {}, or None.".format(
                        isolation_level,
                        ", ".join(f"'{s}'" for s in sorted(self.isolation_levels)),
                    )
                )
        self.isolation_level = isolation_level
        kwargs.update(options)
        return kwargs

    def get_new_connection(self, conn_params: dict[str, Any]) -> MySQLConnection:
        """Open a new MySQLdb connection with the given keyword arguments."""
        connection = Database.connect(**conn_params)
        # bytes encoder in mysqlclient doesn't work and was added only to
        # prevent KeyErrors in Plain < 2.0. We can remove this workaround when
        # mysqlclient 2.1 becomes the minimal mysqlclient supported by Plain.
        # See https://github.com/PyMySQL/mysqlclient/issues/489
        if connection.encoders.get(bytes) is bytes:
            connection.encoders.pop(bytes)
        return connection

    def init_connection_state(self) -> None:
        """
        Apply per-connection session settings after the base initialization.

        The needed SET statements are joined with "; " and sent in a single
        execute() call, so the driver must accept multiple statements per query.
        """
        super().init_connection_state()
        assignments = []
        if self.features.is_sql_auto_is_null_enabled:
            # SQL_AUTO_IS_NULL controls whether an AUTO_INCREMENT column on
            # a recently inserted row will return when the field is tested
            # for NULL. Disabling this brings this aspect of MySQL in line
            # with SQL standards.
            assignments.append("SET SQL_AUTO_IS_NULL = 0")

        if self.isolation_level:
            assignments.append(
                f"SET SESSION TRANSACTION ISOLATION LEVEL {self.isolation_level.upper()}"
            )

        if assignments:
            with self.cursor() as cursor:
                cursor.execute("; ".join(assignments))

    def create_cursor(self, name: str | None = None) -> CursorWrapper:
        """Return a new error-translating cursor; ``name`` is ignored."""
        cursor = self.connection.cursor()
        return CursorWrapper(cursor)

    def _rollback(self) -> None:
        # A NotSupportedError from the driver during rollback is swallowed.
        # NOTE(review): presumably for connections/tables that cannot roll
        # back -- confirm the intended cases.
        try:
            super()._rollback()
        except Database.NotSupportedError:
            pass

    def _set_autocommit(self, autocommit: bool) -> None:
        with self.wrap_database_errors:
            self.connection.autocommit(autocommit)

    def check_constraints(self, table_names: list[str] | None = None) -> None:
        """
        Check ``table_names`` for rows with invalid foreign key references.

        Checks every table reported by introspection when ``table_names`` is
        None; tables without a primary key column are skipped.

        Raises:
            IntegrityError: for the first offending row found.
        """
        with self.cursor() as cursor:
            if table_names is None:
                table_names = self.introspection.table_names(cursor)
            for table_name in table_names:
                primary_key_column_name = self.introspection.get_primary_key_column(
                    cursor, table_name
                )
                if not primary_key_column_name:
                    continue
                relations = self.introspection.get_relations(cursor, table_name)
                for column_name, (
                    referenced_column_name,
                    referenced_table_name,
                ) in relations.items():
                    cursor.execute(
                        f"""
                        SELECT REFERRING.`{primary_key_column_name}`, REFERRING.`{column_name}` FROM `{table_name}` as REFERRING
                        LEFT JOIN `{referenced_table_name}` as REFERRED
                        ON (REFERRING.`{column_name}` = REFERRED.`{referenced_column_name}`)
                        WHERE REFERRING.`{column_name}` IS NOT NULL AND REFERRED.`{referenced_column_name}` IS NULL
                        """
                    )
                    # Only the first dangling reference is reported.
                    bad_row = cursor.fetchone()
                    if bad_row is not None:
                        raise IntegrityError(
                            f"The row in table '{table_name}' with primary key '{bad_row[0]}' has an "
                            f"invalid foreign key: {table_name}.{column_name} contains a value '{bad_row[1]}' that "
                            f"does not have a corresponding value in {referenced_table_name}.{referenced_column_name}."
                        )

    def is_usable(self) -> bool:
        """Return True if pinging the current connection succeeds."""
        try:
            self.connection.ping()
        except Database.Error:
            return False
        else:
            return True

    @cached_property
    def display_name(self) -> str:
        return "MariaDB" if self.mysql_is_mariadb else "MySQL"

    @cached_property
    def data_type_check_constraints(self) -> dict[str, str]:
        """Column CHECK constraints per field type, if the server supports them."""
        if self.features.supports_column_check_constraints:
            check_constraints = {
                "PositiveBigIntegerField": "`%(column)s` >= 0",
                "PositiveIntegerField": "`%(column)s` >= 0",
                "PositiveSmallIntegerField": "`%(column)s` >= 0",
            }
            if self.mysql_is_mariadb and self.mysql_version < (10, 4, 3):
                # MariaDB < 10.4.3 doesn't automatically use the JSON_VALID as
                # a check constraint.
                check_constraints["JSONField"] = "JSON_VALID(`%(column)s`)"
            return check_constraints
        return {}

    @cached_property
    def mysql_server_data(self) -> dict[str, Any]:
        """Fetch (once) a snapshot of server variables via a temporary connection."""
        with self.temporary_connection() as cursor:
            # Select some server variables and test if the time zone
            # definitions are installed. CONVERT_TZ returns NULL if 'UTC'
            # timezone isn't loaded into the mysql.time_zone table.
            cursor.execute(
                """
                SELECT VERSION(),
                       @@sql_mode,
                       @@default_storage_engine,
                       @@sql_auto_is_null,
                       @@lower_case_table_names,
                       CONVERT_TZ('2001-01-01 01:00:00', 'UTC', 'UTC') IS NOT NULL
                """
            )
            row = cursor.fetchone()
        return {
            "version": row[0],
            "sql_mode": row[1],
            "default_storage_engine": row[2],
            "sql_auto_is_null": bool(row[3]),
            "lower_case_table_names": bool(row[4]),
            "has_zoneinfo_database": bool(row[5]),
        }

    @cached_property
    def mysql_server_info(self) -> str:
        """The raw VERSION() string reported by the server."""
        return self.mysql_server_data["version"]

    @cached_property
    def mysql_version(self) -> tuple[int, int, int]:
        """Parse ``mysql_server_info`` into (major, minor, patch); raise if unparseable."""
        match = server_version_re.match(self.mysql_server_info)
        if not match:
            raise Exception(
                f"Unable to determine MySQL version from version string {self.mysql_server_info!r}"
            )
        return cast(tuple[int, int, int], tuple(int(x) for x in match.groups()))

    @cached_property
    def mysql_is_mariadb(self) -> bool:
        """True when the server version string mentions MariaDB."""
        return "mariadb" in self.mysql_server_info.lower()

    @cached_property
    def sql_mode(self) -> set[str]:
        """The server's @@sql_mode flags as a set (empty if none are set)."""
        sql_mode = self.mysql_server_data["sql_mode"]
        return set(sql_mode.split(",") if sql_mode else ())