Plain is headed towards 1.0! Subscribe for development updates →

  1"""
  2MySQL database backend for Plain.
  3
  4Requires mysqlclient: https://pypi.org/project/mysqlclient/
  5"""
  6
  7from __future__ import annotations
  8
  9from functools import cached_property
 10from typing import Any
 11
 12import MySQLdb as Database
 13from MySQLdb.constants import CLIENT, FIELD_TYPE
 14from MySQLdb.converters import conversions
 15
 16from plain.exceptions import ImproperlyConfigured
 17from plain.models.backends import utils as backend_utils
 18from plain.models.backends.base.base import BaseDatabaseWrapper
 19from plain.models.db import IntegrityError
 20from plain.utils.regex_helper import _lazy_re_compile
 21
 22# With mysqlclient stubs, we can now type the connection
 23try:
 24    from MySQLdb.connections import Connection as MySQLConnection
 25except ImportError:
 26    MySQLConnection = Any  # type: ignore[misc, assignment]
 27
 28from .client import DatabaseClient
 29from .creation import DatabaseCreation
 30from .features import DatabaseFeatures
 31from .introspection import DatabaseIntrospection
 32from .operations import DatabaseOperations
 33from .schema import DatabaseSchemaEditor
 34from .validation import DatabaseValidation
 35
 36# MySQLdb returns TIME columns as timedelta -- they are more like timedelta in
 37# terms of actual behavior as they are signed and include days -- and Plain
 38# expects time.
 39plain_conversions = {
 40    **conversions,
 41    **{FIELD_TYPE.TIME: backend_utils.typecast_time},
 42}
 43
 44# This should match the numerical portion of the version numbers (we can treat
 45# versions like 5.0.24 and 5.0.24a as the same).
 46server_version_re = _lazy_re_compile(r"(\d{1,2})\.(\d{1,2})\.(\d{1,2})")
 47
 48
 49class CursorWrapper:
 50    """
 51    A thin wrapper around MySQLdb's normal cursor class that catches particular
 52    exception instances and reraises them with the correct types.
 53
 54    Implemented as a wrapper, rather than a subclass, so that it isn't stuck
 55    to the particular underlying representation returned by Connection.cursor().
 56    """
 57
 58    codes_for_integrityerror = (
 59        1048,  # Column cannot be null
 60        1690,  # BIGINT UNSIGNED value is out of range
 61        3819,  # CHECK constraint is violated
 62        4025,  # CHECK constraint failed
 63    )
 64
 65    def __init__(self, cursor: Any) -> None:
 66        self.cursor = cursor
 67
 68    def execute(self, query: str, args: Any = None) -> int:
 69        try:
 70            # args is None means no string interpolation
 71            return self.cursor.execute(query, args)
 72        except Database.OperationalError as e:
 73            # Map some error codes to IntegrityError, since they seem to be
 74            # misclassified and Plain would prefer the more logical place.
 75            if e.args[0] in self.codes_for_integrityerror:
 76                raise IntegrityError(*tuple(e.args))
 77            raise
 78
 79    def executemany(self, query: str, args: Any) -> int:
 80        try:
 81            return self.cursor.executemany(query, args)
 82        except Database.OperationalError as e:
 83            # Map some error codes to IntegrityError, since they seem to be
 84            # misclassified and Plain would prefer the more logical place.
 85            if e.args[0] in self.codes_for_integrityerror:
 86                raise IntegrityError(*tuple(e.args))
 87            raise
 88
 89    def __getattr__(self, attr: str) -> Any:
 90        return getattr(self.cursor, attr)
 91
 92    def __iter__(self) -> Any:
 93        return iter(self.cursor)
 94
 95
 96class MySQLDatabaseWrapper(BaseDatabaseWrapper):
 97    # Type checker hints: narrow base class attribute types to backend-specific classes
 98    ops: DatabaseOperations
 99    features: DatabaseFeatures
100    introspection: DatabaseIntrospection
101    creation: DatabaseCreation
102
103    vendor = "mysql"
104    # This dictionary maps Field objects to their associated MySQL column
105    # types, as strings. Column-type strings can contain format strings; they'll
106    # be interpolated against the values of Field.__dict__ before being output.
107    # If a column type is set to None, it won't be included in the output.
108    data_types = {
109        "PrimaryKeyField": "bigint AUTO_INCREMENT",
110        "BinaryField": "longblob",
111        "BooleanField": "bool",
112        "CharField": "varchar(%(max_length)s)",
113        "DateField": "date",
114        "DateTimeField": "datetime(6)",
115        "DecimalField": "numeric(%(max_digits)s, %(decimal_places)s)",
116        "DurationField": "bigint",
117        "FloatField": "double precision",
118        "IntegerField": "integer",
119        "BigIntegerField": "bigint",
120        "GenericIPAddressField": "char(39)",
121        "JSONField": "json",
122        "PositiveBigIntegerField": "bigint UNSIGNED",
123        "PositiveIntegerField": "integer UNSIGNED",
124        "PositiveSmallIntegerField": "smallint UNSIGNED",
125        "SmallIntegerField": "smallint",
126        "TextField": "longtext",
127        "TimeField": "time(6)",
128        "UUIDField": "char(32)",
129    }
130
131    # For these data types:
132    # - MySQL < 8.0.13 doesn't accept default values and implicitly treats them
133    #   as nullable
134    # - all versions of MySQL and MariaDB don't support full width database
135    #   indexes
136    _limited_data_types = (
137        "tinyblob",
138        "blob",
139        "mediumblob",
140        "longblob",
141        "tinytext",
142        "text",
143        "mediumtext",
144        "longtext",
145        "json",
146    )
147
148    operators = {
149        "exact": "= %s",
150        "iexact": "LIKE %s",
151        "contains": "LIKE BINARY %s",
152        "icontains": "LIKE %s",
153        "gt": "> %s",
154        "gte": ">= %s",
155        "lt": "< %s",
156        "lte": "<= %s",
157        "startswith": "LIKE BINARY %s",
158        "endswith": "LIKE BINARY %s",
159        "istartswith": "LIKE %s",
160        "iendswith": "LIKE %s",
161    }
162
163    # The patterns below are used to generate SQL pattern lookup clauses when
164    # the right-hand side of the lookup isn't a raw string (it might be an expression
165    # or the result of a bilateral transformation).
166    # In those cases, special characters for LIKE operators (e.g. \, *, _) should be
167    # escaped on database side.
168    #
169    # Note: we use str.format() here for readability as '%' is used as a wildcard for
170    # the LIKE operator.
171    pattern_esc = r"REPLACE(REPLACE(REPLACE({}, '\\', '\\\\'), '%%', '\%%'), '_', '\_')"
172    pattern_ops = {
173        "contains": "LIKE BINARY CONCAT('%%', {}, '%%')",
174        "icontains": "LIKE CONCAT('%%', {}, '%%')",
175        "startswith": "LIKE BINARY CONCAT({}, '%%')",
176        "istartswith": "LIKE CONCAT({}, '%%')",
177        "endswith": "LIKE BINARY CONCAT('%%', {})",
178        "iendswith": "LIKE CONCAT('%%', {})",
179    }
180
181    isolation_levels = {
182        "read uncommitted",
183        "read committed",
184        "repeatable read",
185        "serializable",
186    }
187
188    Database = Database
189    SchemaEditorClass = DatabaseSchemaEditor
190    # Classes instantiated in __init__().
191    client_class = DatabaseClient
192    creation_class = DatabaseCreation
193    features_class = DatabaseFeatures
194    introspection_class = DatabaseIntrospection
195    ops_class = DatabaseOperations
196    validation_class = DatabaseValidation
197
198    def get_database_version(self) -> tuple[int, int, int]:
199        return self.mysql_version
200
201    def get_connection_params(self) -> dict[str, Any]:
202        kwargs: dict[str, Any] = {
203            "conv": plain_conversions,
204            "charset": "utf8",
205        }
206        settings_dict = self.settings_dict
207        if settings_dict["USER"]:
208            kwargs["user"] = settings_dict["USER"]
209        if settings_dict["NAME"]:
210            kwargs["database"] = settings_dict["NAME"]
211        if settings_dict["PASSWORD"]:
212            kwargs["password"] = settings_dict["PASSWORD"]
213        if settings_dict["HOST"].startswith("/"):
214            kwargs["unix_socket"] = settings_dict["HOST"]
215        elif settings_dict["HOST"]:
216            kwargs["host"] = settings_dict["HOST"]
217        if settings_dict["PORT"]:
218            kwargs["port"] = int(settings_dict["PORT"])
219        # We need the number of potentially affected rows after an
220        # "UPDATE", not the number of changed rows.
221        kwargs["client_flag"] = CLIENT.FOUND_ROWS
222        # Validate the transaction isolation level, if specified.
223        options = settings_dict["OPTIONS"].copy()
224        isolation_level = options.pop("isolation_level", "read committed")
225        if isolation_level:
226            isolation_level = isolation_level.lower()
227            if isolation_level not in self.isolation_levels:
228                raise ImproperlyConfigured(
229                    "Invalid transaction isolation level '{}' specified.\n"
230                    "Use one of {}, or None.".format(
231                        isolation_level,
232                        ", ".join(f"'{s}'" for s in sorted(self.isolation_levels)),
233                    )
234                )
235        self.isolation_level = isolation_level
236        kwargs.update(options)
237        return kwargs
238
239    def get_new_connection(self, conn_params: dict[str, Any]) -> MySQLConnection:
240        connection = Database.connect(**conn_params)
241        # bytes encoder in mysqlclient doesn't work and was added only to
242        # prevent KeyErrors in Plain < 2.0. We can remove this workaround when
243        # mysqlclient 2.1 becomes the minimal mysqlclient supported by Plain.
244        # See https://github.com/PyMySQL/mysqlclient/issues/489
245        if connection.encoders.get(bytes) is bytes:
246            connection.encoders.pop(bytes)
247        return connection
248
249    def init_connection_state(self) -> None:
250        super().init_connection_state()
251        assignments = []
252        if self.features.is_sql_auto_is_null_enabled:
253            # SQL_AUTO_IS_NULL controls whether an AUTO_INCREMENT column on
254            # a recently inserted row will return when the field is tested
255            # for NULL. Disabling this brings this aspect of MySQL in line
256            # with SQL standards.
257            assignments.append("SET SQL_AUTO_IS_NULL = 0")
258
259        if self.isolation_level:
260            assignments.append(
261                f"SET SESSION TRANSACTION ISOLATION LEVEL {self.isolation_level.upper()}"
262            )
263
264        if assignments:
265            with self.cursor() as cursor:
266                cursor.execute("; ".join(assignments))
267
268    def create_cursor(self, name: str | None = None) -> CursorWrapper:
269        cursor = self.connection.cursor()
270        return CursorWrapper(cursor)
271
272    def _rollback(self) -> None:
273        try:
274            BaseDatabaseWrapper._rollback(self)
275        except Database.NotSupportedError:
276            pass
277
278    def _set_autocommit(self, autocommit: bool) -> None:
279        with self.wrap_database_errors:
280            self.connection.autocommit(autocommit)
281
282    def check_constraints(self, table_names: list[str] | None = None) -> None:
283        """Check ``table_names`` for rows with invalid foreign key references."""
284        with self.cursor() as cursor:
285            if table_names is None:
286                table_names = self.introspection.table_names(cursor)
287            for table_name in table_names:
288                primary_key_column_name = self.introspection.get_primary_key_column(
289                    cursor, table_name
290                )
291                if not primary_key_column_name:
292                    continue
293                relations = self.introspection.get_relations(cursor, table_name)
294                for column_name, (
295                    referenced_column_name,
296                    referenced_table_name,
297                ) in relations.items():
298                    cursor.execute(
299                        f"""
300                        SELECT REFERRING.`{primary_key_column_name}`, REFERRING.`{column_name}` FROM `{table_name}` as REFERRING
301                        LEFT JOIN `{referenced_table_name}` as REFERRED
302                        ON (REFERRING.`{column_name}` = REFERRED.`{referenced_column_name}`)
303                        WHERE REFERRING.`{column_name}` IS NOT NULL AND REFERRED.`{referenced_column_name}` IS NULL
304                        """
305                    )
306                    for bad_row in cursor.fetchall():
307                        raise IntegrityError(
308                            f"The row in table '{table_name}' with primary key '{bad_row[0]}' has an "
309                            f"invalid foreign key: {table_name}.{column_name} contains a value '{bad_row[1]}' that "
310                            f"does not have a corresponding value in {referenced_table_name}.{referenced_column_name}."
311                        )
312
313    def is_usable(self) -> bool:
314        try:
315            self.connection.ping()
316        except Database.Error:
317            return False
318        else:
319            return True
320
321    @cached_property
322    def display_name(self) -> str:
323        return "MariaDB" if self.mysql_is_mariadb else "MySQL"
324
325    @cached_property
326    def data_type_check_constraints(self) -> dict[str, str]:
327        if self.features.supports_column_check_constraints:
328            check_constraints = {
329                "PositiveBigIntegerField": "`%(column)s` >= 0",
330                "PositiveIntegerField": "`%(column)s` >= 0",
331                "PositiveSmallIntegerField": "`%(column)s` >= 0",
332            }
333            if self.mysql_is_mariadb and self.mysql_version < (10, 4, 3):
334                # MariaDB < 10.4.3 doesn't automatically use the JSON_VALID as
335                # a check constraint.
336                check_constraints["JSONField"] = "JSON_VALID(`%(column)s`)"
337            return check_constraints
338        return {}
339
340    @cached_property
341    def mysql_server_data(self) -> dict[str, Any]:
342        with self.temporary_connection() as cursor:
343            # Select some server variables and test if the time zone
344            # definitions are installed. CONVERT_TZ returns NULL if 'UTC'
345            # timezone isn't loaded into the mysql.time_zone table.
346            cursor.execute(
347                """
348                SELECT VERSION(),
349                       @@sql_mode,
350                       @@default_storage_engine,
351                       @@sql_auto_is_null,
352                       @@lower_case_table_names,
353                       CONVERT_TZ('2001-01-01 01:00:00', 'UTC', 'UTC') IS NOT NULL
354            """
355            )
356            row = cursor.fetchone()
357        return {
358            "version": row[0],
359            "sql_mode": row[1],
360            "default_storage_engine": row[2],
361            "sql_auto_is_null": bool(row[3]),
362            "lower_case_table_names": bool(row[4]),
363            "has_zoneinfo_database": bool(row[5]),
364        }
365
366    @cached_property
367    def mysql_server_info(self) -> str:
368        return self.mysql_server_data["version"]
369
370    @cached_property
371    def mysql_version(self) -> tuple[int, int, int]:
372        match = server_version_re.match(self.mysql_server_info)
373        if not match:
374            raise Exception(
375                f"Unable to determine MySQL version from version string {self.mysql_server_info!r}"
376            )
377        return tuple(int(x) for x in match.groups())
378
379    @cached_property
380    def mysql_is_mariadb(self) -> bool:
381        return "mariadb" in self.mysql_server_info.lower()
382
383    @cached_property
384    def sql_mode(self) -> set[str]:
385        sql_mode = self.mysql_server_data["sql_mode"]
386        return set(sql_mode.split(",") if sql_mode else ())