1"""
2MySQL database backend for Plain.
3
4Requires mysqlclient: https://pypi.org/project/mysqlclient/
5"""
6
7from functools import cached_property
8
9import MySQLdb as Database
10from MySQLdb.constants import CLIENT, FIELD_TYPE
11from MySQLdb.converters import conversions
12
13from plain.exceptions import ImproperlyConfigured
14from plain.models.backends import utils as backend_utils
15from plain.models.backends.base.base import BaseDatabaseWrapper
16from plain.models.db import IntegrityError
17from plain.utils.regex_helper import _lazy_re_compile
18
19from .client import DatabaseClient
20from .creation import DatabaseCreation
21from .features import DatabaseFeatures
22from .introspection import DatabaseIntrospection
23from .operations import DatabaseOperations
24from .schema import DatabaseSchemaEditor
25from .validation import DatabaseValidation
26
# MySQLdb hands TIME columns back as datetime.timedelta objects. That is a
# reasonable fit for MySQL's TIME type, which is signed and may span more
# than one day, but Plain expects time values, so Plain's own typecaster
# replaces the driver's TIME converter in a copy of the default table.
plain_conversions = dict(conversions)
plain_conversions[FIELD_TYPE.TIME] = backend_utils.typecast_time
34
server_version_re = _lazy_re_compile(r"(\d+)\.(\d+)\.(\d+)")
38
39
class CursorWrapper:
    """
    A thin wrapper around MySQLdb's normal cursor class that catches particular
    exception instances and reraises them with the correct types.

    Implemented as a wrapper, rather than a subclass, so that it isn't stuck
    to the particular underlying representation returned by Connection.cursor().
    """

    # MySQL error codes that MySQLdb reports as OperationalError but that
    # describe data-integrity violations; they are re-raised as IntegrityError.
    codes_for_integrityerror = (
        1048,  # Column cannot be null
        1690,  # BIGINT UNSIGNED value is out of range
        3819,  # CHECK constraint is violated
        4025,  # CHECK constraint failed
    )

    def __init__(self, cursor):
        self.cursor = cursor

    def _call_translating_errors(self, method, query, args):
        """
        Return ``method(query, args)``, translating misclassified errors.

        An OperationalError whose first argument (the MySQL error code) is in
        ``codes_for_integrityerror`` is re-raised as IntegrityError carrying
        the same args and chained to the original. Any other OperationalError
        propagates unchanged.
        """
        try:
            return method(query, args)
        except Database.OperationalError as e:
            # Map some error codes to IntegrityError, since they seem to be
            # misclassified and Plain would prefer the more logical place.
            # Guard on e.args so an argument-less error can't raise IndexError
            # and mask the original exception.
            if e.args and e.args[0] in self.codes_for_integrityerror:
                raise IntegrityError(*e.args) from e
            raise

    def execute(self, query, args=None):
        """Execute ``query``; ``args`` of None means no string interpolation."""
        return self._call_translating_errors(self.cursor.execute, query, args)

    def executemany(self, query, args):
        """Execute ``query`` once per parameter set in ``args``."""
        return self._call_translating_errors(self.cursor.executemany, query, args)

    def __getattr__(self, attr):
        # Anything not defined here (fetchone, rowcount, close, ...) is
        # delegated to the wrapped MySQLdb cursor.
        return getattr(self.cursor, attr)

    def __iter__(self):
        # Special methods are looked up on the type, bypassing __getattr__,
        # so iteration must be forwarded explicitly.
        return iter(self.cursor)
85
86
class DatabaseWrapper(BaseDatabaseWrapper):
    """
    Plain database wrapper for MySQL and MariaDB servers, using mysqlclient.

    Besides the connection lifecycle hooks, it exposes cached information about
    the connected server (version string and tuple, sql_mode, whether it is
    MariaDB), all derived from the single query run in ``mysql_server_data``.
    """

    vendor = "mysql"
    # This dictionary maps Field objects to their associated MySQL column
    # types, as strings. Column-type strings can contain format strings; they'll
    # be interpolated against the values of Field.__dict__ before being output.
    # If a column type is set to None, it won't be included in the output.
    data_types = {
        "PrimaryKeyField": "bigint AUTO_INCREMENT",
        "BinaryField": "longblob",
        "BooleanField": "bool",
        "CharField": "varchar(%(max_length)s)",
        "DateField": "date",
        "DateTimeField": "datetime(6)",
        "DecimalField": "numeric(%(max_digits)s, %(decimal_places)s)",
        "DurationField": "bigint",
        "FloatField": "double precision",
        "IntegerField": "integer",
        "BigIntegerField": "bigint",
        "IPAddressField": "char(15)",
        "GenericIPAddressField": "char(39)",
        "JSONField": "json",
        "PositiveBigIntegerField": "bigint UNSIGNED",
        "PositiveIntegerField": "integer UNSIGNED",
        "PositiveSmallIntegerField": "smallint UNSIGNED",
        "SmallIntegerField": "smallint",
        "TextField": "longtext",
        "TimeField": "time(6)",
        "UUIDField": "char(32)",
    }

    # For these data types:
    # - MySQL < 8.0.13 doesn't accept default values and implicitly treats them
    #   as nullable
    # - all versions of MySQL and MariaDB don't support full width database
    #   indexes
    _limited_data_types = (
        "tinyblob",
        "blob",
        "mediumblob",
        "longblob",
        "tinytext",
        "text",
        "mediumtext",
        "longtext",
        "json",
    )

    operators = {
        "exact": "= %s",
        "iexact": "LIKE %s",
        "contains": "LIKE BINARY %s",
        "icontains": "LIKE %s",
        "gt": "> %s",
        "gte": ">= %s",
        "lt": "< %s",
        "lte": "<= %s",
        "startswith": "LIKE BINARY %s",
        "endswith": "LIKE BINARY %s",
        "istartswith": "LIKE %s",
        "iendswith": "LIKE %s",
    }

    # The patterns below are used to generate SQL pattern lookup clauses when
    # the right-hand side of the lookup isn't a raw string (it might be an
    # expression or the result of a bilateral transformation).
    # In those cases, special characters for LIKE operators (e.g. \, %, _)
    # should be escaped on the database side.
    #
    # Note: we use str.format() here for readability as '%' is used as a
    # wildcard for the LIKE operator.
    pattern_esc = r"REPLACE(REPLACE(REPLACE({}, '\\', '\\\\'), '%%', '\%%'), '_', '\_')"
    pattern_ops = {
        "contains": "LIKE BINARY CONCAT('%%', {}, '%%')",
        "icontains": "LIKE CONCAT('%%', {}, '%%')",
        "startswith": "LIKE BINARY CONCAT({}, '%%')",
        "istartswith": "LIKE CONCAT({}, '%%')",
        "endswith": "LIKE BINARY CONCAT('%%', {})",
        "iendswith": "LIKE CONCAT('%%', {})",
    }

    # Accepted (lower-cased) values for OPTIONS["isolation_level"].
    isolation_levels = {
        "read uncommitted",
        "read committed",
        "repeatable read",
        "serializable",
    }

    Database = Database
    SchemaEditorClass = DatabaseSchemaEditor
    # Classes instantiated in __init__().
    client_class = DatabaseClient
    creation_class = DatabaseCreation
    features_class = DatabaseFeatures
    introspection_class = DatabaseIntrospection
    ops_class = DatabaseOperations
    validation_class = DatabaseValidation

    def get_database_version(self):
        """Return the server version as an integer tuple (``mysql_version``)."""
        return self.mysql_version

    def get_connection_params(self):
        """
        Build the keyword arguments passed to ``MySQLdb.connect()``.

        Falsy USER, NAME, PASSWORD, HOST and PORT settings are simply not
        passed. A HOST starting with "/" is used as a Unix socket path.
        OPTIONS entries are passed through verbatim, except
        ``isolation_level`` (default "read committed"). That entry is
        validated and stored on ``self.isolation_level`` for
        ``init_connection_state()`` instead of being sent to the driver.

        Raises ImproperlyConfigured if the isolation level is not recognized.
        """
        kwargs = {
            "conv": plain_conversions,
            "charset": "utf8",
        }
        settings_dict = self.settings_dict
        if settings_dict["USER"]:
            kwargs["user"] = settings_dict["USER"]
        if settings_dict["NAME"]:
            kwargs["database"] = settings_dict["NAME"]
        if settings_dict["PASSWORD"]:
            kwargs["password"] = settings_dict["PASSWORD"]
        # Treat a HOST of None like an empty HOST (no host passed) instead of
        # failing on None.startswith().
        host = settings_dict["HOST"] or ""
        if host.startswith("/"):
            kwargs["unix_socket"] = host
        elif host:
            kwargs["host"] = host
        if settings_dict["PORT"]:
            kwargs["port"] = int(settings_dict["PORT"])
        # We need the number of potentially affected rows after an
        # "UPDATE", not the number of changed rows.
        kwargs["client_flag"] = CLIENT.FOUND_ROWS
        # Validate the transaction isolation level, if specified.
        options = settings_dict["OPTIONS"].copy()
        isolation_level = options.pop("isolation_level", "read committed")
        if isolation_level:
            isolation_level = isolation_level.lower()
            if isolation_level not in self.isolation_levels:
                raise ImproperlyConfigured(
                    "Invalid transaction isolation level '{}' specified.\n"
                    "Use one of {}, or None.".format(
                        isolation_level,
                        ", ".join(f"'{s}'" for s in sorted(self.isolation_levels)),
                    )
                )
        self.isolation_level = isolation_level
        kwargs.update(options)
        return kwargs

    def get_new_connection(self, conn_params):
        """Open a MySQLdb connection from ``get_connection_params()`` output."""
        connection = Database.connect(**conn_params)
        # mysqlclient's ``bytes`` encoder (registered as ``bytes`` itself)
        # doesn't work, so drop it when present. This workaround can be removed
        # once mysqlclient 2.1 is the minimum supported mysqlclient version.
        # See https://github.com/PyMySQL/mysqlclient/issues/489
        if connection.encoders.get(bytes) is bytes:
            connection.encoders.pop(bytes)
        return connection

    def init_connection_state(self):
        """
        Apply per-session settings on a freshly opened connection.

        Disables SQL_AUTO_IS_NULL when the server has it enabled and sets the
        session isolation level chosen in ``get_connection_params()``. Both
        statements are sent together in a single execute() call.
        """
        super().init_connection_state()
        assignments = []
        if self.features.is_sql_auto_is_null_enabled:
            # SQL_AUTO_IS_NULL controls whether an AUTO_INCREMENT column on
            # a recently inserted row will return when the field is tested
            # for NULL. Disabling this brings this aspect of MySQL in line
            # with SQL standards.
            assignments.append("SET SQL_AUTO_IS_NULL = 0")

        if self.isolation_level:
            assignments.append(
                f"SET SESSION TRANSACTION ISOLATION LEVEL {self.isolation_level.upper()}"
            )

        if assignments:
            with self.cursor() as cursor:
                cursor.execute("; ".join(assignments))

    def create_cursor(self, name=None):
        """Return a CursorWrapper around a new MySQLdb cursor; ``name`` is ignored."""
        cursor = self.connection.cursor()
        return CursorWrapper(cursor)

    def _rollback(self):
        """Roll back the current transaction, ignoring NotSupportedError."""
        try:
            super()._rollback()
        except Database.NotSupportedError:
            # NOTE(review): presumably raised when the underlying tables or
            # engine can't roll back; the error is deliberately ignored.
            pass

    def _set_autocommit(self, autocommit):
        """Toggle driver-level autocommit, translating driver errors."""
        with self.wrap_database_errors:
            self.connection.autocommit(autocommit)

    def check_constraints(self, table_names=None):
        """
        Check ``table_names`` for rows with invalid foreign key references.

        When ``table_names`` is None, every table reported by introspection is
        checked. Tables without a primary key are skipped. Raises
        IntegrityError describing the first offending row found.
        """
        with self.cursor() as cursor:
            if table_names is None:
                table_names = self.introspection.table_names(cursor)
            for table_name in table_names:
                primary_key_column_name = self.introspection.get_primary_key_column(
                    cursor, table_name
                )
                if not primary_key_column_name:
                    continue
                relations = self.introspection.get_relations(cursor, table_name)
                for column_name, (
                    referenced_column_name,
                    referenced_table_name,
                ) in relations.items():
                    # Anti-join: referring rows with a non-NULL FK value that
                    # has no matching row in the referenced table.
                    cursor.execute(
                        f"""
                        SELECT REFERRING.`{primary_key_column_name}`, REFERRING.`{column_name}` FROM `{table_name}` as REFERRING
                        LEFT JOIN `{referenced_table_name}` as REFERRED
                        ON (REFERRING.`{column_name}` = REFERRED.`{referenced_column_name}`)
                        WHERE REFERRING.`{column_name}` IS NOT NULL AND REFERRED.`{referenced_column_name}` IS NULL
                        """
                    )
                    for bad_row in cursor.fetchall():
                        raise IntegrityError(
                            f"The row in table '{table_name}' with primary key '{bad_row[0]}' has an "
                            f"invalid foreign key: {table_name}.{column_name} contains a value '{bad_row[1]}' that "
                            f"does not have a corresponding value in {referenced_table_name}.{referenced_column_name}."
                        )

    def is_usable(self):
        """Return True if the connection still answers ``ping()``."""
        try:
            self.connection.ping()
        except Database.Error:
            return False
        else:
            return True

    @cached_property
    def display_name(self):
        """Human-readable server flavor: "MariaDB" or "MySQL"."""
        return "MariaDB" if self.mysql_is_mariadb else "MySQL"

    @cached_property
    def data_type_check_constraints(self):
        """
        Map field type names to CHECK constraint templates.

        Returns {} when the server doesn't support column CHECK constraints.
        """
        if self.features.supports_column_check_constraints:
            check_constraints = {
                "PositiveBigIntegerField": "`%(column)s` >= 0",
                "PositiveIntegerField": "`%(column)s` >= 0",
                "PositiveSmallIntegerField": "`%(column)s` >= 0",
            }
            if self.mysql_is_mariadb and self.mysql_version < (10, 4, 3):
                # MariaDB < 10.4.3 doesn't automatically use the JSON_VALID as
                # a check constraint.
                check_constraints["JSONField"] = "JSON_VALID(`%(column)s`)"
            return check_constraints
        return {}

    @cached_property
    def mysql_server_data(self):
        """
        Fetch server facts once, over a temporary connection.

        Returns a dict with keys "version", "sql_mode",
        "default_storage_engine", "sql_auto_is_null",
        "lower_case_table_names" and "has_zoneinfo_database".
        """
        with self.temporary_connection() as cursor:
            # Select some server variables and test if the time zone
            # definitions are installed. CONVERT_TZ returns NULL if 'UTC'
            # timezone isn't loaded into the mysql.time_zone table.
            cursor.execute(
                """
                SELECT VERSION(),
                       @@sql_mode,
                       @@default_storage_engine,
                       @@sql_auto_is_null,
                       @@lower_case_table_names,
                       CONVERT_TZ('2001-01-01 01:00:00', 'UTC', 'UTC') IS NOT NULL
                """
            )
            row = cursor.fetchone()
        return {
            "version": row[0],
            "sql_mode": row[1],
            "default_storage_engine": row[2],
            "sql_auto_is_null": bool(row[3]),
            "lower_case_table_names": bool(row[4]),
            "has_zoneinfo_database": bool(row[5]),
        }

    @cached_property
    def mysql_server_info(self):
        """Raw server version string as returned by ``SELECT VERSION()``."""
        return self.mysql_server_data["version"]

    @cached_property
    def mysql_version(self):
        """
        Server version as a tuple of three ints parsed from the version string.

        Raises Exception if no "X.Y.Z" prefix can be found.
        """
        match = server_version_re.match(self.mysql_server_info)
        if not match:
            raise Exception(
                f"Unable to determine MySQL version from version string {self.mysql_server_info!r}"
            )
        return tuple(int(x) for x in match.groups())

    @cached_property
    def mysql_is_mariadb(self):
        """True when the server version string mentions MariaDB."""
        return "mariadb" in self.mysql_server_info.lower()

    @cached_property
    def sql_mode(self):
        """The server's @@sql_mode flags as a set of strings (empty if unset)."""
        sql_mode = self.mysql_server_data["sql_mode"]
        return set(sql_mode.split(",") if sql_mode else ())