Plain is headed towards 1.0! Subscribe for development updates →

  1"""
  2MySQL database backend for Plain.
  3
  4Requires mysqlclient: https://pypi.org/project/mysqlclient/
  5"""
  6
  7from functools import cached_property
  8
  9import MySQLdb as Database
 10from MySQLdb.constants import CLIENT, FIELD_TYPE
 11from MySQLdb.converters import conversions
 12
 13from plain.exceptions import ImproperlyConfigured
 14from plain.models.backends import utils as backend_utils
 15from plain.models.backends.base.base import BaseDatabaseWrapper
 16from plain.models.db import IntegrityError
 17from plain.utils.regex_helper import _lazy_re_compile
 18
 19from .client import DatabaseClient
 20from .creation import DatabaseCreation
 21from .features import DatabaseFeatures
 22from .introspection import DatabaseIntrospection
 23from .operations import DatabaseOperations
 24from .schema import DatabaseSchemaEditor
 25from .validation import DatabaseValidation
 26
 27# MySQLdb returns TIME columns as timedelta -- they are more like timedelta in
 28# terms of actual behavior as they are signed and include days -- and Plain
 29# expects time.
 30plain_conversions = {
 31    **conversions,
 32    **{FIELD_TYPE.TIME: backend_utils.typecast_time},
 33}
 34
 35# This should match the numerical portion of the version numbers (we can treat
 36# versions like 5.0.24 and 5.0.24a as the same).
 37server_version_re = _lazy_re_compile(r"(\d{1,2})\.(\d{1,2})\.(\d{1,2})")
 38
 39
 40class CursorWrapper:
 41    """
 42    A thin wrapper around MySQLdb's normal cursor class that catches particular
 43    exception instances and reraises them with the correct types.
 44
 45    Implemented as a wrapper, rather than a subclass, so that it isn't stuck
 46    to the particular underlying representation returned by Connection.cursor().
 47    """
 48
 49    codes_for_integrityerror = (
 50        1048,  # Column cannot be null
 51        1690,  # BIGINT UNSIGNED value is out of range
 52        3819,  # CHECK constraint is violated
 53        4025,  # CHECK constraint failed
 54    )
 55
 56    def __init__(self, cursor):
 57        self.cursor = cursor
 58
 59    def execute(self, query, args=None):
 60        try:
 61            # args is None means no string interpolation
 62            return self.cursor.execute(query, args)
 63        except Database.OperationalError as e:
 64            # Map some error codes to IntegrityError, since they seem to be
 65            # misclassified and Plain would prefer the more logical place.
 66            if e.args[0] in self.codes_for_integrityerror:
 67                raise IntegrityError(*tuple(e.args))
 68            raise
 69
 70    def executemany(self, query, args):
 71        try:
 72            return self.cursor.executemany(query, args)
 73        except Database.OperationalError as e:
 74            # Map some error codes to IntegrityError, since they seem to be
 75            # misclassified and Plain would prefer the more logical place.
 76            if e.args[0] in self.codes_for_integrityerror:
 77                raise IntegrityError(*tuple(e.args))
 78            raise
 79
 80    def __getattr__(self, attr):
 81        return getattr(self.cursor, attr)
 82
 83    def __iter__(self):
 84        return iter(self.cursor)
 85
 86
 87class DatabaseWrapper(BaseDatabaseWrapper):
 88    vendor = "mysql"
 89    # This dictionary maps Field objects to their associated MySQL column
 90    # types, as strings. Column-type strings can contain format strings; they'll
 91    # be interpolated against the values of Field.__dict__ before being output.
 92    # If a column type is set to None, it won't be included in the output.
 93    data_types = {
 94        "PrimaryKeyField": "bigint AUTO_INCREMENT",
 95        "BinaryField": "longblob",
 96        "BooleanField": "bool",
 97        "CharField": "varchar(%(max_length)s)",
 98        "DateField": "date",
 99        "DateTimeField": "datetime(6)",
100        "DecimalField": "numeric(%(max_digits)s, %(decimal_places)s)",
101        "DurationField": "bigint",
102        "FloatField": "double precision",
103        "IntegerField": "integer",
104        "BigIntegerField": "bigint",
105        "IPAddressField": "char(15)",
106        "GenericIPAddressField": "char(39)",
107        "JSONField": "json",
108        "PositiveBigIntegerField": "bigint UNSIGNED",
109        "PositiveIntegerField": "integer UNSIGNED",
110        "PositiveSmallIntegerField": "smallint UNSIGNED",
111        "SmallIntegerField": "smallint",
112        "TextField": "longtext",
113        "TimeField": "time(6)",
114        "UUIDField": "char(32)",
115    }
116
117    # For these data types:
118    # - MySQL < 8.0.13 doesn't accept default values and implicitly treats them
119    #   as nullable
120    # - all versions of MySQL and MariaDB don't support full width database
121    #   indexes
122    _limited_data_types = (
123        "tinyblob",
124        "blob",
125        "mediumblob",
126        "longblob",
127        "tinytext",
128        "text",
129        "mediumtext",
130        "longtext",
131        "json",
132    )
133
134    operators = {
135        "exact": "= %s",
136        "iexact": "LIKE %s",
137        "contains": "LIKE BINARY %s",
138        "icontains": "LIKE %s",
139        "gt": "> %s",
140        "gte": ">= %s",
141        "lt": "< %s",
142        "lte": "<= %s",
143        "startswith": "LIKE BINARY %s",
144        "endswith": "LIKE BINARY %s",
145        "istartswith": "LIKE %s",
146        "iendswith": "LIKE %s",
147    }
148
149    # The patterns below are used to generate SQL pattern lookup clauses when
150    # the right-hand side of the lookup isn't a raw string (it might be an expression
151    # or the result of a bilateral transformation).
152    # In those cases, special characters for LIKE operators (e.g. \, *, _) should be
153    # escaped on database side.
154    #
155    # Note: we use str.format() here for readability as '%' is used as a wildcard for
156    # the LIKE operator.
157    pattern_esc = r"REPLACE(REPLACE(REPLACE({}, '\\', '\\\\'), '%%', '\%%'), '_', '\_')"
158    pattern_ops = {
159        "contains": "LIKE BINARY CONCAT('%%', {}, '%%')",
160        "icontains": "LIKE CONCAT('%%', {}, '%%')",
161        "startswith": "LIKE BINARY CONCAT({}, '%%')",
162        "istartswith": "LIKE CONCAT({}, '%%')",
163        "endswith": "LIKE BINARY CONCAT('%%', {})",
164        "iendswith": "LIKE CONCAT('%%', {})",
165    }
166
167    isolation_levels = {
168        "read uncommitted",
169        "read committed",
170        "repeatable read",
171        "serializable",
172    }
173
174    Database = Database
175    SchemaEditorClass = DatabaseSchemaEditor
176    # Classes instantiated in __init__().
177    client_class = DatabaseClient
178    creation_class = DatabaseCreation
179    features_class = DatabaseFeatures
180    introspection_class = DatabaseIntrospection
181    ops_class = DatabaseOperations
182    validation_class = DatabaseValidation
183
184    def get_database_version(self):
185        return self.mysql_version
186
187    def get_connection_params(self):
188        kwargs = {
189            "conv": plain_conversions,
190            "charset": "utf8",
191        }
192        settings_dict = self.settings_dict
193        if settings_dict["USER"]:
194            kwargs["user"] = settings_dict["USER"]
195        if settings_dict["NAME"]:
196            kwargs["database"] = settings_dict["NAME"]
197        if settings_dict["PASSWORD"]:
198            kwargs["password"] = settings_dict["PASSWORD"]
199        if settings_dict["HOST"].startswith("/"):
200            kwargs["unix_socket"] = settings_dict["HOST"]
201        elif settings_dict["HOST"]:
202            kwargs["host"] = settings_dict["HOST"]
203        if settings_dict["PORT"]:
204            kwargs["port"] = int(settings_dict["PORT"])
205        # We need the number of potentially affected rows after an
206        # "UPDATE", not the number of changed rows.
207        kwargs["client_flag"] = CLIENT.FOUND_ROWS
208        # Validate the transaction isolation level, if specified.
209        options = settings_dict["OPTIONS"].copy()
210        isolation_level = options.pop("isolation_level", "read committed")
211        if isolation_level:
212            isolation_level = isolation_level.lower()
213            if isolation_level not in self.isolation_levels:
214                raise ImproperlyConfigured(
215                    "Invalid transaction isolation level '{}' specified.\n"
216                    "Use one of {}, or None.".format(
217                        isolation_level,
218                        ", ".join(f"'{s}'" for s in sorted(self.isolation_levels)),
219                    )
220                )
221        self.isolation_level = isolation_level
222        kwargs.update(options)
223        return kwargs
224
225    def get_new_connection(self, conn_params):
226        connection = Database.connect(**conn_params)
227        # bytes encoder in mysqlclient doesn't work and was added only to
228        # prevent KeyErrors in Plain < 2.0. We can remove this workaround when
229        # mysqlclient 2.1 becomes the minimal mysqlclient supported by Plain.
230        # See https://github.com/PyMySQL/mysqlclient/issues/489
231        if connection.encoders.get(bytes) is bytes:
232            connection.encoders.pop(bytes)
233        return connection
234
235    def init_connection_state(self):
236        super().init_connection_state()
237        assignments = []
238        if self.features.is_sql_auto_is_null_enabled:
239            # SQL_AUTO_IS_NULL controls whether an AUTO_INCREMENT column on
240            # a recently inserted row will return when the field is tested
241            # for NULL. Disabling this brings this aspect of MySQL in line
242            # with SQL standards.
243            assignments.append("SET SQL_AUTO_IS_NULL = 0")
244
245        if self.isolation_level:
246            assignments.append(
247                f"SET SESSION TRANSACTION ISOLATION LEVEL {self.isolation_level.upper()}"
248            )
249
250        if assignments:
251            with self.cursor() as cursor:
252                cursor.execute("; ".join(assignments))
253
254    def create_cursor(self, name=None):
255        cursor = self.connection.cursor()
256        return CursorWrapper(cursor)
257
258    def _rollback(self):
259        try:
260            BaseDatabaseWrapper._rollback(self)
261        except Database.NotSupportedError:
262            pass
263
264    def _set_autocommit(self, autocommit):
265        with self.wrap_database_errors:
266            self.connection.autocommit(autocommit)
267
268    def check_constraints(self, table_names=None):
269        """Check ``table_names`` for rows with invalid foreign key references."""
270        with self.cursor() as cursor:
271            if table_names is None:
272                table_names = self.introspection.table_names(cursor)
273            for table_name in table_names:
274                primary_key_column_name = self.introspection.get_primary_key_column(
275                    cursor, table_name
276                )
277                if not primary_key_column_name:
278                    continue
279                relations = self.introspection.get_relations(cursor, table_name)
280                for column_name, (
281                    referenced_column_name,
282                    referenced_table_name,
283                ) in relations.items():
284                    cursor.execute(
285                        f"""
286                        SELECT REFERRING.`{primary_key_column_name}`, REFERRING.`{column_name}` FROM `{table_name}` as REFERRING
287                        LEFT JOIN `{referenced_table_name}` as REFERRED
288                        ON (REFERRING.`{column_name}` = REFERRED.`{referenced_column_name}`)
289                        WHERE REFERRING.`{column_name}` IS NOT NULL AND REFERRED.`{referenced_column_name}` IS NULL
290                        """
291                    )
292                    for bad_row in cursor.fetchall():
293                        raise IntegrityError(
294                            f"The row in table '{table_name}' with primary key '{bad_row[0]}' has an "
295                            f"invalid foreign key: {table_name}.{column_name} contains a value '{bad_row[1]}' that "
296                            f"does not have a corresponding value in {referenced_table_name}.{referenced_column_name}."
297                        )
298
299    def is_usable(self):
300        try:
301            self.connection.ping()
302        except Database.Error:
303            return False
304        else:
305            return True
306
307    @cached_property
308    def display_name(self):
309        return "MariaDB" if self.mysql_is_mariadb else "MySQL"
310
311    @cached_property
312    def data_type_check_constraints(self):
313        if self.features.supports_column_check_constraints:
314            check_constraints = {
315                "PositiveBigIntegerField": "`%(column)s` >= 0",
316                "PositiveIntegerField": "`%(column)s` >= 0",
317                "PositiveSmallIntegerField": "`%(column)s` >= 0",
318            }
319            if self.mysql_is_mariadb and self.mysql_version < (10, 4, 3):
320                # MariaDB < 10.4.3 doesn't automatically use the JSON_VALID as
321                # a check constraint.
322                check_constraints["JSONField"] = "JSON_VALID(`%(column)s`)"
323            return check_constraints
324        return {}
325
326    @cached_property
327    def mysql_server_data(self):
328        with self.temporary_connection() as cursor:
329            # Select some server variables and test if the time zone
330            # definitions are installed. CONVERT_TZ returns NULL if 'UTC'
331            # timezone isn't loaded into the mysql.time_zone table.
332            cursor.execute(
333                """
334                SELECT VERSION(),
335                       @@sql_mode,
336                       @@default_storage_engine,
337                       @@sql_auto_is_null,
338                       @@lower_case_table_names,
339                       CONVERT_TZ('2001-01-01 01:00:00', 'UTC', 'UTC') IS NOT NULL
340            """
341            )
342            row = cursor.fetchone()
343        return {
344            "version": row[0],
345            "sql_mode": row[1],
346            "default_storage_engine": row[2],
347            "sql_auto_is_null": bool(row[3]),
348            "lower_case_table_names": bool(row[4]),
349            "has_zoneinfo_database": bool(row[5]),
350        }
351
352    @cached_property
353    def mysql_server_info(self):
354        return self.mysql_server_data["version"]
355
356    @cached_property
357    def mysql_version(self):
358        match = server_version_re.match(self.mysql_server_info)
359        if not match:
360            raise Exception(
361                f"Unable to determine MySQL version from version string {self.mysql_server_info!r}"
362            )
363        return tuple(int(x) for x in match.groups())
364
365    @cached_property
366    def mysql_is_mariadb(self):
367        return "mariadb" in self.mysql_server_info.lower()
368
369    @cached_property
370    def sql_mode(self):
371        sql_mode = self.mysql_server_data["sql_mode"]
372        return set(sql_mode.split(",") if sql_mode else ())