1"""
2MySQL database backend for Plain.
3
4Requires mysqlclient: https://pypi.org/project/mysqlclient/
5"""
6
7from functools import cached_property
8
9import MySQLdb as Database
10from MySQLdb.constants import CLIENT, FIELD_TYPE
11from MySQLdb.converters import conversions
12
13from plain.exceptions import ImproperlyConfigured
14from plain.models.backends import utils as backend_utils
15from plain.models.backends.base.base import BaseDatabaseWrapper
16from plain.models.db import IntegrityError
17from plain.utils.regex_helper import _lazy_re_compile
18
19from .client import DatabaseClient
20from .creation import DatabaseCreation
21from .features import DatabaseFeatures
22from .introspection import DatabaseIntrospection
23from .operations import DatabaseOperations
24from .schema import DatabaseSchemaEditor
25from .validation import DatabaseValidation
26
# MySQLdb returns TIME columns as timedelta -- they are more like timedelta in
# terms of actual behavior as they are signed and include days -- and Plain
# expects time. Start from MySQLdb's default converters and override only the
# TIME converter.
plain_conversions = {
    **conversions,
    FIELD_TYPE.TIME: backend_utils.typecast_time,
}
34
# This should match the numerical portion of the version numbers (we can treat
# versions like 5.0.24 and 5.0.24a as the same). Each component is matched
# with \d+ rather than a fixed width so that a version such as "8.0.100" is not
# silently truncated to (8, 0, 10) by the anchored match().
server_version_re = _lazy_re_compile(r"(\d+)\.(\d+)\.(\d+)")
38
39
class CursorWrapper:
    """
    A thin wrapper around MySQLdb's normal cursor class that catches particular
    exception instances and reraises them with the correct types.

    Implemented as a wrapper, rather than a subclass, so that it isn't stuck
    to the particular underlying representation returned by Connection.cursor().
    Any attribute not defined here is forwarded to the wrapped cursor.
    """

    # MySQL error codes reported as OperationalError that are really
    # integrity violations.
    codes_for_integrityerror = (
        1048,  # Column cannot be null
        1690,  # BIGINT UNSIGNED value is out of range
        3819,  # CHECK constraint is violated
        4025,  # CHECK constraint failed
    )

    def __init__(self, cursor):
        self.cursor = cursor

    def execute(self, query, args=None):
        """Execute ``query`` on the wrapped cursor, translating integrity errors."""
        try:
            # args is None means no string interpolation
            return self.cursor.execute(query, args)
        except Database.OperationalError as e:
            self._reraise_as_integrity_error(e)
            raise

    def executemany(self, query, args):
        """Execute ``query`` once per parameter set, translating integrity errors."""
        try:
            return self.cursor.executemany(query, args)
        except Database.OperationalError as e:
            self._reraise_as_integrity_error(e)
            raise

    def _reraise_as_integrity_error(self, error):
        """
        Raise IntegrityError if ``error`` carries a misclassified error code.

        Some error codes are reported by MySQL as OperationalError, but Plain
        would prefer the more logical place. Returns None (so the caller can
        re-raise the original error) when the code is not one of those.
        """
        # Guard against an error without args rather than hitting IndexError.
        if error.args and error.args[0] in self.codes_for_integrityerror:
            raise IntegrityError(*error.args) from error

    def __getattr__(self, attr):
        return getattr(self.cursor, attr)

    def __iter__(self):
        return iter(self.cursor)
85
86
class DatabaseWrapper(BaseDatabaseWrapper):
    """
    Plain database wrapper for MySQL and MariaDB, built on mysqlclient.

    Supplies the MySQL-specific column types, lookup operators and connection
    handling. Server version and settings are queried lazily, once, and cached.
    """

    vendor = "mysql"
    # This dictionary maps Field class names to their associated MySQL column
    # types, as strings. Column-type strings can contain format strings; they'll
    # be interpolated against the values of Field.__dict__ before being output.
    # If a column type is set to None, it won't be included in the output.
    data_types = {
        "AutoField": "integer AUTO_INCREMENT",
        "BigAutoField": "bigint AUTO_INCREMENT",
        "BinaryField": "longblob",
        "BooleanField": "bool",
        "CharField": "varchar(%(max_length)s)",
        "DateField": "date",
        "DateTimeField": "datetime(6)",
        "DecimalField": "numeric(%(max_digits)s, %(decimal_places)s)",
        "DurationField": "bigint",
        "FloatField": "double precision",
        "IntegerField": "integer",
        "BigIntegerField": "bigint",
        "IPAddressField": "char(15)",
        "GenericIPAddressField": "char(39)",
        "JSONField": "json",
        "PositiveBigIntegerField": "bigint UNSIGNED",
        "PositiveIntegerField": "integer UNSIGNED",
        "PositiveSmallIntegerField": "smallint UNSIGNED",
        "SmallAutoField": "smallint AUTO_INCREMENT",
        "SmallIntegerField": "smallint",
        "TextField": "longtext",
        "TimeField": "time(6)",
        "UUIDField": "char(32)",
    }

    # For these data types:
    # - MySQL < 8.0.13 doesn't accept default values and implicitly treats them
    #   as nullable
    # - all versions of MySQL and MariaDB don't support full width database
    #   indexes
    _limited_data_types = (
        "tinyblob",
        "blob",
        "mediumblob",
        "longblob",
        "tinytext",
        "text",
        "mediumtext",
        "longtext",
        "json",
    )

    # Lookup name -> SQL comparison template for a right-hand side that is a
    # plain parameter. The BINARY variants make the match case-sensitive.
    operators = {
        "exact": "= %s",
        "iexact": "LIKE %s",
        "contains": "LIKE BINARY %s",
        "icontains": "LIKE %s",
        "gt": "> %s",
        "gte": ">= %s",
        "lt": "< %s",
        "lte": "<= %s",
        "startswith": "LIKE BINARY %s",
        "endswith": "LIKE BINARY %s",
        "istartswith": "LIKE %s",
        "iendswith": "LIKE %s",
    }

    # The patterns below are used to generate SQL pattern lookup clauses when
    # the right-hand side of the lookup isn't a raw string (it might be an
    # expression or the result of a bilateral transformation).
    # In those cases, special characters for LIKE operators (e.g. \, %, _)
    # should be escaped on the database side.
    #
    # Note: we use str.format() here for readability as '%' is used as a
    # wildcard for the LIKE operator.
    pattern_esc = r"REPLACE(REPLACE(REPLACE({}, '\\', '\\\\'), '%%', '\%%'), '_', '\_')"
    pattern_ops = {
        "contains": "LIKE BINARY CONCAT('%%', {}, '%%')",
        "icontains": "LIKE CONCAT('%%', {}, '%%')",
        "startswith": "LIKE BINARY CONCAT({}, '%%')",
        "istartswith": "LIKE CONCAT({}, '%%')",
        "endswith": "LIKE BINARY CONCAT('%%', {})",
        "iendswith": "LIKE CONCAT('%%', {})",
    }

    # Accepted (lower-cased) values for OPTIONS["isolation_level"].
    isolation_levels = {
        "read uncommitted",
        "read committed",
        "repeatable read",
        "serializable",
    }

    Database = Database
    SchemaEditorClass = DatabaseSchemaEditor
    # Classes instantiated in __init__().
    client_class = DatabaseClient
    creation_class = DatabaseCreation
    features_class = DatabaseFeatures
    introspection_class = DatabaseIntrospection
    ops_class = DatabaseOperations
    validation_class = DatabaseValidation

    def get_database_version(self):
        """Return the server version as a tuple of ints, e.g. (8, 0, 36)."""
        return self.mysql_version

    def get_connection_params(self):
        """
        Build the keyword arguments for ``MySQLdb.connect()``.

        Also validates OPTIONS["isolation_level"] (default "read committed")
        and stores the lower-cased result on ``self.isolation_level``.
        Remaining OPTIONS are passed through and override the defaults.

        Raises:
            ImproperlyConfigured: if the isolation level is not recognized.
        """
        kwargs = {
            "conv": plain_conversions,
            "charset": "utf8",
        }
        settings_dict = self.settings_dict
        if settings_dict["USER"]:
            kwargs["user"] = settings_dict["USER"]
        if settings_dict["NAME"]:
            kwargs["database"] = settings_dict["NAME"]
        if settings_dict["PASSWORD"]:
            kwargs["password"] = settings_dict["PASSWORD"]
        host = settings_dict["HOST"]
        # A HOST starting with "/" is a path to a unix socket. Checking
        # truthiness first also tolerates HOST being None.
        if host and host.startswith("/"):
            kwargs["unix_socket"] = host
        elif host:
            kwargs["host"] = host
        if settings_dict["PORT"]:
            kwargs["port"] = int(settings_dict["PORT"])
        # We need the number of potentially affected rows after an
        # "UPDATE", not the number of changed rows.
        kwargs["client_flag"] = CLIENT.FOUND_ROWS
        # Validate the transaction isolation level, if specified.
        options = settings_dict["OPTIONS"].copy()
        isolation_level = options.pop("isolation_level", "read committed")
        if isolation_level:
            isolation_level = isolation_level.lower()
            if isolation_level not in self.isolation_levels:
                raise ImproperlyConfigured(
                    "Invalid transaction isolation level '{}' specified.\n"
                    "Use one of {}, or None.".format(
                        isolation_level,
                        ", ".join(f"'{s}'" for s in sorted(self.isolation_levels)),
                    )
                )
        self.isolation_level = isolation_level
        kwargs.update(options)
        return kwargs

    def get_new_connection(self, conn_params):
        """Open a MySQLdb connection using ``conn_params``."""
        connection = Database.connect(**conn_params)
        # bytes encoder in mysqlclient doesn't work and was added only to
        # prevent KeyErrors in Plain < 2.0. We can remove this workaround when
        # mysqlclient 2.1 becomes the minimal mysqlclient supported by Plain.
        # See https://github.com/PyMySQL/mysqlclient/issues/489
        if connection.encoders.get(bytes) is bytes:
            connection.encoders.pop(bytes)
        return connection

    def init_connection_state(self):
        """Apply per-session settings right after connecting."""
        super().init_connection_state()
        assignments = []
        if self.features.is_sql_auto_is_null_enabled:
            # SQL_AUTO_IS_NULL controls whether an AUTO_INCREMENT column on
            # a recently inserted row will return when the field is tested
            # for NULL. Disabling this brings this aspect of MySQL in line
            # with SQL standards.
            assignments.append("SET SQL_AUTO_IS_NULL = 0")

        if self.isolation_level:
            assignments.append(
                f"SET SESSION TRANSACTION ISOLATION LEVEL {self.isolation_level.upper()}"
            )

        if assignments:
            # Send all assignments in a single round trip.
            with self.cursor() as cursor:
                cursor.execute("; ".join(assignments))

    def create_cursor(self, name=None):
        """Return a new cursor wrapped in :class:`CursorWrapper`; ``name`` is unused."""
        cursor = self.connection.cursor()
        return CursorWrapper(cursor)

    def _rollback(self):
        """Roll back, ignoring servers/engines that report rollback as unsupported."""
        try:
            super()._rollback()
        except Database.NotSupportedError:
            pass

    def _set_autocommit(self, autocommit):
        with self.wrap_database_errors:
            self.connection.autocommit(autocommit)

    def check_constraints(self, table_names=None):
        """
        Check ``table_names`` for rows with invalid foreign key references.

        All introspected tables are checked when ``table_names`` is None.
        Tables without a primary key are skipped.

        Raises:
            IntegrityError: describing the first offending row found.
        """
        with self.cursor() as cursor:
            if table_names is None:
                table_names = self.introspection.table_names(cursor)
            for table_name in table_names:
                primary_key_column_name = self.introspection.get_primary_key_column(
                    cursor, table_name
                )
                if not primary_key_column_name:
                    continue
                relations = self.introspection.get_relations(cursor, table_name)
                for column_name, (
                    referenced_column_name,
                    referenced_table_name,
                ) in relations.items():
                    # Rows whose non-NULL foreign key has no matching
                    # referenced row.
                    cursor.execute(
                        f"""
                        SELECT REFERRING.`{primary_key_column_name}`, REFERRING.`{column_name}` FROM `{table_name}` as REFERRING
                        LEFT JOIN `{referenced_table_name}` as REFERRED
                        ON (REFERRING.`{column_name}` = REFERRED.`{referenced_column_name}`)
                        WHERE REFERRING.`{column_name}` IS NOT NULL AND REFERRED.`{referenced_column_name}` IS NULL
                        """
                    )
                    # Only the first offending row is reported.
                    bad_row = cursor.fetchone()
                    if bad_row is not None:
                        raise IntegrityError(
                            f"The row in table '{table_name}' with primary key '{bad_row[0]}' has an "
                            f"invalid foreign key: {table_name}.{column_name} contains a value '{bad_row[1]}' that "
                            f"does not have a corresponding value in {referenced_table_name}.{referenced_column_name}."
                        )

    def is_usable(self):
        """Return True if the current connection still answers a ping."""
        try:
            self.connection.ping()
        except Database.Error:
            return False
        else:
            return True

    @cached_property
    def display_name(self):
        return "MariaDB" if self.mysql_is_mariadb else "MySQL"

    @cached_property
    def data_type_check_constraints(self):
        """
        Map Field class names to CHECK constraint templates.

        Empty when the server doesn't support column check constraints.
        """
        if self.features.supports_column_check_constraints:
            check_constraints = {
                "PositiveBigIntegerField": "`%(column)s` >= 0",
                "PositiveIntegerField": "`%(column)s` >= 0",
                "PositiveSmallIntegerField": "`%(column)s` >= 0",
            }
            if self.mysql_is_mariadb and self.mysql_version < (10, 4, 3):
                # MariaDB < 10.4.3 doesn't automatically use the JSON_VALID as
                # a check constraint.
                check_constraints["JSONField"] = "JSON_VALID(`%(column)s`)"
            return check_constraints
        return {}

    @cached_property
    def mysql_server_data(self):
        """Fetch (once) the server version and a few server variables."""
        with self.temporary_connection() as cursor:
            # Select some server variables and test if the time zone
            # definitions are installed. CONVERT_TZ returns NULL if 'UTC'
            # timezone isn't loaded into the mysql.time_zone table.
            cursor.execute(
                """
                SELECT VERSION(),
                       @@sql_mode,
                       @@default_storage_engine,
                       @@sql_auto_is_null,
                       @@lower_case_table_names,
                       CONVERT_TZ('2001-01-01 01:00:00', 'UTC', 'UTC') IS NOT NULL
                """
            )
            row = cursor.fetchone()
        return {
            "version": row[0],
            "sql_mode": row[1],
            "default_storage_engine": row[2],
            "sql_auto_is_null": bool(row[3]),
            "lower_case_table_names": bool(row[4]),
            "has_zoneinfo_database": bool(row[5]),
        }

    @cached_property
    def mysql_server_info(self):
        """The raw VERSION() string reported by the server."""
        return self.mysql_server_data["version"]

    @cached_property
    def mysql_version(self):
        """
        The server version as a tuple of three ints.

        Raises:
            Exception: if the version string can't be parsed.
        """
        match = server_version_re.match(self.mysql_server_info)
        if not match:
            raise Exception(
                f"Unable to determine MySQL version from version string {self.mysql_server_info!r}"
            )
        return tuple(int(x) for x in match.groups())

    @cached_property
    def mysql_is_mariadb(self):
        """True when the server's VERSION() string mentions MariaDB."""
        return "mariadb" in self.mysql_server_info.lower()

    @cached_property
    def sql_mode(self):
        """The server's @@sql_mode as a set of mode names (empty if unset)."""
        sql_mode = self.mysql_server_data["sql_mode"]
        return set(sql_mode.split(",") if sql_mode else ())