1"""
2MySQL database backend for Plain.
3
4Requires mysqlclient: https://pypi.org/project/mysqlclient/
5"""
6from plain.exceptions import ImproperlyConfigured
7from plain.models.backends import utils as backend_utils
8from plain.models.backends.base.base import BaseDatabaseWrapper
9from plain.models.db import IntegrityError
10from plain.utils.functional import cached_property
11from plain.utils.regex_helper import _lazy_re_compile
12
13try:
14 import MySQLdb as Database
15except ImportError as err:
16 raise ImproperlyConfigured(
17 "Error loading MySQLdb module.\nDid you install mysqlclient?"
18 ) from err
19
20from MySQLdb.constants import CLIENT, FIELD_TYPE
21from MySQLdb.converters import conversions
22
23# Some of these import MySQLdb, so import them after checking if it's installed.
24from .client import DatabaseClient
25from .creation import DatabaseCreation
26from .features import DatabaseFeatures
27from .introspection import DatabaseIntrospection
28from .operations import DatabaseOperations
29from .schema import DatabaseSchemaEditor
30from .validation import DatabaseValidation
31
# Refuse to run against mysqlclient releases older than the supported minimum.
version = Database.version_info
if version < (1, 4, 3):
    raise ImproperlyConfigured(
        "mysqlclient 1.4.3 or newer is required; you have %s." % Database.__version__
    )


# MySQLdb returns TIME columns as timedelta -- they are more like timedelta in
# terms of actual behavior as they are signed and include days -- and Plain
# expects time. Start from a copy of MySQLdb's default converter table and
# override only the TIME entry.
plain_conversions = dict(conversions)
plain_conversions[FIELD_TYPE.TIME] = backend_utils.typecast_time

# Captures the numeric major.minor.patch portion of a server version string,
# allowing one or two digits per component. Trailing suffixes are ignored, so
# versions like 5.0.24 and 5.0.24a are treated as the same.
server_version_re = _lazy_re_compile(r"(\d{1,2})\.(\d{1,2})\.(\d{1,2})")
50
51
class CursorWrapper:
    """
    A thin wrapper around MySQLdb's normal cursor class that catches particular
    exception instances and reraises them with the correct types.

    Implemented as a wrapper, rather than a subclass, so that it isn't stuck
    to the particular underlying representation returned by Connection.cursor().
    """

    # MySQL error codes that MySQLdb reports as OperationalError but which
    # describe integrity violations; execute()/executemany() re-raise them as
    # IntegrityError.
    codes_for_integrityerror = (
        1048,  # Column cannot be null
        1690,  # BIGINT UNSIGNED value is out of range
        3819,  # CHECK constraint is violated
        4025,  # CHECK constraint failed
    )

    def __init__(self, cursor):
        self.cursor = cursor

    def _reraise_if_integrity_error(self, error):
        """
        Raise IntegrityError, chained from ``error``, when the first argument
        of ``error`` is one of ``codes_for_integrityerror``.

        Otherwise return None so the caller can re-raise the original
        exception unchanged.
        """
        # Exceptions built without args would make ``args[0]`` raise
        # IndexError here and hide the real database error.
        if error.args and error.args[0] in self.codes_for_integrityerror:
            raise IntegrityError(*error.args) from error

    def execute(self, query, args=None):
        """Run ``query`` on the wrapped cursor, translating integrity errors."""
        try:
            # args is None means no string interpolation
            return self.cursor.execute(query, args)
        except Database.OperationalError as e:
            # Map some error codes to IntegrityError, since they seem to be
            # misclassified and Plain would prefer the more logical place.
            self._reraise_if_integrity_error(e)
            raise

    def executemany(self, query, args):
        """Run ``query`` once per item of ``args``, translating integrity errors."""
        try:
            return self.cursor.executemany(query, args)
        except Database.OperationalError as e:
            # Same error-code translation as execute().
            self._reraise_if_integrity_error(e)
            raise

    def __getattr__(self, attr):
        # Delegate everything not defined here to the wrapped cursor.
        return getattr(self.cursor, attr)

    def __iter__(self):
        return iter(self.cursor)
97
98
class DatabaseWrapper(BaseDatabaseWrapper):
    """
    MySQL/MariaDB database wrapper built on mysqlclient (MySQLdb).

    Declares the backend's column types, lookup operators and helper classes,
    builds the keyword arguments for MySQLdb.connect() from the connection
    settings, and exposes cached information queried from the server.
    """

    vendor = "mysql"
    # This dictionary maps Field objects to their associated MySQL column
    # types, as strings. Column-type strings can contain format strings; they'll
    # be interpolated against the values of Field.__dict__ before being output.
    # If a column type is set to None, it won't be included in the output.
    data_types = {
        "AutoField": "integer AUTO_INCREMENT",
        "BigAutoField": "bigint AUTO_INCREMENT",
        "BinaryField": "longblob",
        "BooleanField": "bool",
        "CharField": "varchar(%(max_length)s)",
        "DateField": "date",
        "DateTimeField": "datetime(6)",
        "DecimalField": "numeric(%(max_digits)s, %(decimal_places)s)",
        "DurationField": "bigint",
        "FloatField": "double precision",
        "IntegerField": "integer",
        "BigIntegerField": "bigint",
        "IPAddressField": "char(15)",
        "GenericIPAddressField": "char(39)",
        "JSONField": "json",
        "OneToOneField": "integer",
        "PositiveBigIntegerField": "bigint UNSIGNED",
        "PositiveIntegerField": "integer UNSIGNED",
        "PositiveSmallIntegerField": "smallint UNSIGNED",
        "SlugField": "varchar(%(max_length)s)",
        "SmallAutoField": "smallint AUTO_INCREMENT",
        "SmallIntegerField": "smallint",
        "TextField": "longtext",
        "TimeField": "time(6)",
        "UUIDField": "char(32)",
    }

    # For these data types:
    # - MySQL < 8.0.13 doesn't accept default values and implicitly treats them
    #   as nullable
    # - all versions of MySQL and MariaDB don't support full width database
    #   indexes
    _limited_data_types = (
        "tinyblob",
        "blob",
        "mediumblob",
        "longblob",
        "tinytext",
        "text",
        "mediumtext",
        "longtext",
        "json",
    )

    operators = {
        "exact": "= %s",
        "iexact": "LIKE %s",
        "contains": "LIKE BINARY %s",
        "icontains": "LIKE %s",
        "gt": "> %s",
        "gte": ">= %s",
        "lt": "< %s",
        "lte": "<= %s",
        "startswith": "LIKE BINARY %s",
        "endswith": "LIKE BINARY %s",
        "istartswith": "LIKE %s",
        "iendswith": "LIKE %s",
    }

    # The patterns below are used to generate SQL pattern lookup clauses when
    # the right-hand side of the lookup isn't a raw string (it might be an
    # expression or the result of a bilateral transformation).
    # In those cases, special characters for LIKE operators (\, % and _) must
    # be escaped on the database side.
    #
    # Note: we use str.format() here for readability as '%' is used as a
    # wildcard for the LIKE operator.
    pattern_esc = r"REPLACE(REPLACE(REPLACE({}, '\\', '\\\\'), '%%', '\%%'), '_', '\_')"
    pattern_ops = {
        "contains": "LIKE BINARY CONCAT('%%', {}, '%%')",
        "icontains": "LIKE CONCAT('%%', {}, '%%')",
        "startswith": "LIKE BINARY CONCAT({}, '%%')",
        "istartswith": "LIKE CONCAT({}, '%%')",
        "endswith": "LIKE BINARY CONCAT('%%', {})",
        "iendswith": "LIKE CONCAT('%%', {})",
    }

    # Accepted (lower-cased) values for OPTIONS["isolation_level"].
    isolation_levels = {
        "read uncommitted",
        "read committed",
        "repeatable read",
        "serializable",
    }

    Database = Database
    SchemaEditorClass = DatabaseSchemaEditor
    # Classes instantiated in __init__().
    client_class = DatabaseClient
    creation_class = DatabaseCreation
    features_class = DatabaseFeatures
    introspection_class = DatabaseIntrospection
    ops_class = DatabaseOperations
    validation_class = DatabaseValidation

    def get_database_version(self):
        """Return the server version as a tuple of ints (see mysql_version)."""
        return self.mysql_version

    def get_connection_params(self):
        """
        Build the keyword arguments for MySQLdb.connect() from settings_dict.

        The "isolation_level" option, which defaults to "read committed", is
        popped from OPTIONS, validated, and stored on self.isolation_level.
        All remaining OPTIONS are merged last, so they override the values
        computed here.

        Raises ImproperlyConfigured if the isolation level is not recognised.
        """
        kwargs = {
            "conv": plain_conversions,
            "charset": "utf8",
        }
        settings_dict = self.settings_dict
        if settings_dict["USER"]:
            kwargs["user"] = settings_dict["USER"]
        if settings_dict["NAME"]:
            kwargs["database"] = settings_dict["NAME"]
        if settings_dict["PASSWORD"]:
            kwargs["password"] = settings_dict["PASSWORD"]
        # A HOST starting with "/" is a Unix socket path. An empty or None
        # HOST is skipped; calling .startswith on None would crash.
        host = settings_dict["HOST"]
        if host and host.startswith("/"):
            kwargs["unix_socket"] = host
        elif host:
            kwargs["host"] = host
        if settings_dict["PORT"]:
            kwargs["port"] = int(settings_dict["PORT"])
        # We need the number of potentially affected rows after an
        # "UPDATE", not the number of changed rows.
        kwargs["client_flag"] = CLIENT.FOUND_ROWS
        # Validate the transaction isolation level, if specified.
        options = settings_dict["OPTIONS"].copy()
        isolation_level = options.pop("isolation_level", "read committed")
        if isolation_level:
            isolation_level = isolation_level.lower()
            if isolation_level not in self.isolation_levels:
                raise ImproperlyConfigured(
                    "Invalid transaction isolation level '{}' specified.\n"
                    "Use one of {}, or None.".format(
                        isolation_level,
                        ", ".join("'%s'" % s for s in sorted(self.isolation_levels)),
                    )
                )
        self.isolation_level = isolation_level
        kwargs.update(options)
        return kwargs

    def get_new_connection(self, conn_params):
        """Open a MySQLdb connection with the given keyword arguments."""
        connection = Database.connect(**conn_params)
        # bytes encoder in mysqlclient doesn't work and was added only to
        # prevent KeyErrors in Plain < 2.0. We can remove this workaround when
        # mysqlclient 2.1 becomes the minimal mysqlclient supported by Plain.
        # See https://github.com/PyMySQL/mysqlclient/issues/489
        if connection.encoders.get(bytes) is bytes:
            connection.encoders.pop(bytes)
        return connection

    def init_connection_state(self):
        """
        Apply per-session settings after connecting.

        Disables SQL_AUTO_IS_NULL when the features report it as enabled, and
        sets the session isolation level if one was configured. Both settings
        are sent in a single query.
        """
        super().init_connection_state()
        assignments = []
        if self.features.is_sql_auto_is_null_enabled:
            # SQL_AUTO_IS_NULL controls whether an AUTO_INCREMENT column on
            # a recently inserted row will return when the field is tested
            # for NULL. Disabling this brings this aspect of MySQL in line
            # with SQL standards.
            assignments.append("SET SQL_AUTO_IS_NULL = 0")

        if self.isolation_level:
            assignments.append(
                "SET SESSION TRANSACTION ISOLATION LEVEL %s"
                % self.isolation_level.upper()
            )

        if assignments:
            with self.cursor() as cursor:
                cursor.execute("; ".join(assignments))

    def create_cursor(self, name=None):
        """Return a CursorWrapper around a new MySQLdb cursor (name is unused)."""
        return CursorWrapper(self.connection.cursor())

    def _rollback(self):
        # NotSupportedError from MySQLdb during rollback is deliberately
        # ignored; presumably raised when the connection cannot roll back
        # (e.g. non-transactional setups) -- TODO confirm.
        try:
            super()._rollback()
        except Database.NotSupportedError:
            pass

    def _set_autocommit(self, autocommit):
        with self.wrap_database_errors:
            self.connection.autocommit(autocommit)

    def disable_constraint_checking(self):
        """
        Disable foreign key checks, primarily for use in adding rows with
        forward references. Always return True to indicate constraint checks
        need to be re-enabled.
        """
        with self.cursor() as cursor:
            cursor.execute("SET foreign_key_checks=0")
        return True

    def enable_constraint_checking(self):
        """
        Re-enable foreign key checks after they have been disabled.
        """
        # Override needs_rollback in case constraint_checks_disabled is
        # nested inside transaction.atomic.
        self.needs_rollback, needs_rollback = False, self.needs_rollback
        try:
            with self.cursor() as cursor:
                cursor.execute("SET foreign_key_checks=1")
        finally:
            self.needs_rollback = needs_rollback

    def check_constraints(self, table_names=None):
        """
        Check each table name in `table_names` for rows with invalid foreign
        key references. This method is intended to be used in conjunction with
        `disable_constraint_checking()` and `enable_constraint_checking()`, to
        determine if rows with invalid references were entered while constraint
        checks were off.

        Raises IntegrityError describing the first offending row found.
        Tables without a primary key are skipped.
        """
        with self.cursor() as cursor:
            if table_names is None:
                table_names = self.introspection.table_names(cursor)
            for table_name in table_names:
                primary_key_column_name = self.introspection.get_primary_key_column(
                    cursor, table_name
                )
                if not primary_key_column_name:
                    continue
                relations = self.introspection.get_relations(cursor, table_name)
                for column_name, (
                    referenced_column_name,
                    referenced_table_name,
                ) in relations.items():
                    # Rows whose non-NULL FK value has no matching row in the
                    # referenced table.
                    cursor.execute(
                        """
                        SELECT REFERRING.`{}`, REFERRING.`{}` FROM `{}` as REFERRING
                        LEFT JOIN `{}` as REFERRED
                        ON (REFERRING.`{}` = REFERRED.`{}`)
                        WHERE REFERRING.`{}` IS NOT NULL AND REFERRED.`{}` IS NULL
                        """.format(
                            primary_key_column_name,
                            column_name,
                            table_name,
                            referenced_table_name,
                            column_name,
                            referenced_column_name,
                            column_name,
                            referenced_column_name,
                        )
                    )
                    # Only the first violation is reported, so there is no
                    # need to fetch every offending row.
                    bad_row = cursor.fetchone()
                    if bad_row is not None:
                        raise IntegrityError(
                            "The row in table '{}' with primary key '{}' has an "
                            "invalid foreign key: {}.{} contains a value '{}' that "
                            "does not have a corresponding value in {}.{}.".format(
                                table_name,
                                bad_row[0],
                                table_name,
                                column_name,
                                bad_row[1],
                                referenced_table_name,
                                referenced_column_name,
                            )
                        )

    def is_usable(self):
        """Return True if pinging the server succeeds, False on Database.Error."""
        try:
            self.connection.ping()
        except Database.Error:
            return False
        else:
            return True

    @cached_property
    def display_name(self):
        return "MariaDB" if self.mysql_is_mariadb else "MySQL"

    @cached_property
    def data_type_check_constraints(self):
        """
        Map field types to CHECK constraint SQL templates.

        Returns an empty dict when the features report no support for column
        check constraints.
        """
        if self.features.supports_column_check_constraints:
            check_constraints = {
                "PositiveBigIntegerField": "`%(column)s` >= 0",
                "PositiveIntegerField": "`%(column)s` >= 0",
                "PositiveSmallIntegerField": "`%(column)s` >= 0",
            }
            if self.mysql_is_mariadb and self.mysql_version < (10, 4, 3):
                # MariaDB < 10.4.3 doesn't automatically use the JSON_VALID as
                # a check constraint.
                check_constraints["JSONField"] = "JSON_VALID(`%(column)s`)"
            return check_constraints
        return {}

    @cached_property
    def mysql_server_data(self):
        """
        Query the server once for its version, configuration variables and
        time zone table availability. The result is cached.
        """
        with self.temporary_connection() as cursor:
            # Select some server variables and test if the time zone
            # definitions are installed. CONVERT_TZ returns NULL if 'UTC'
            # timezone isn't loaded into the mysql.time_zone table.
            cursor.execute(
                """
                SELECT VERSION(),
                       @@sql_mode,
                       @@default_storage_engine,
                       @@sql_auto_is_null,
                       @@lower_case_table_names,
                       CONVERT_TZ('2001-01-01 01:00:00', 'UTC', 'UTC') IS NOT NULL
            """
            )
            row = cursor.fetchone()
        return {
            "version": row[0],
            "sql_mode": row[1],
            "default_storage_engine": row[2],
            "sql_auto_is_null": bool(row[3]),
            "lower_case_table_names": bool(row[4]),
            "has_zoneinfo_database": bool(row[5]),
        }

    @cached_property
    def mysql_server_info(self):
        """The raw VERSION() string reported by the server."""
        return self.mysql_server_data["version"]

    @cached_property
    def mysql_version(self):
        """
        The (major, minor, patch) server version parsed from mysql_server_info.

        Raises Exception if the version string doesn't start with a
        recognisable version number.
        """
        match = server_version_re.match(self.mysql_server_info)
        if not match:
            raise Exception(
                "Unable to determine MySQL version from version string %r"
                % self.mysql_server_info
            )
        return tuple(int(x) for x in match.groups())

    @cached_property
    def mysql_is_mariadb(self):
        return "mariadb" in self.mysql_server_info.lower()

    @cached_property
    def sql_mode(self):
        """The server's @@sql_mode as a set of mode names (empty if unset)."""
        sql_mode = self.mysql_server_data["sql_mode"]
        return set(sql_mode.split(",") if sql_mode else ())