Plain is headed towards 1.0! Subscribe for development updates →

  1"""
  2MySQL database backend for Plain.
  3
  4Requires mysqlclient: https://pypi.org/project/mysqlclient/
  5"""
  6
  7from __future__ import annotations
  8
  9from functools import cached_property
 10from typing import TYPE_CHECKING, Any, cast
 11
 12import MySQLdb as Database
 13from MySQLdb.constants import CLIENT, FIELD_TYPE
 14from MySQLdb.converters import conversions
 15
 16from plain.exceptions import ImproperlyConfigured
 17from plain.models.backends import utils as backend_utils
 18from plain.models.backends.base.base import BaseDatabaseWrapper
 19from plain.models.db import IntegrityError
 20from plain.utils.regex_helper import _lazy_re_compile
 21
 22# Type checkers always see the proper type; runtime falls back to Any if needed
 23if TYPE_CHECKING:
 24    from MySQLdb.connections import Connection as MySQLConnection
 25else:
 26    try:
 27        from MySQLdb.connections import Connection as MySQLConnection
 28    except ImportError:
 29        MySQLConnection = Any  # type: ignore[misc]
 30
 31from .client import DatabaseClient
 32from .creation import DatabaseCreation
 33from .features import DatabaseFeatures
 34from .introspection import DatabaseIntrospection
 35from .operations import DatabaseOperations
 36from .schema import DatabaseSchemaEditor
 37from .validation import DatabaseValidation
 38
 39# MySQLdb returns TIME columns as timedelta -- they are more like timedelta in
 40# terms of actual behavior as they are signed and include days -- and Plain
 41# expects time.
 42plain_conversions = {
 43    **conversions,
 44    **{FIELD_TYPE.TIME: backend_utils.typecast_time},
 45}
 46
 47# This should match the numerical portion of the version numbers (we can treat
 48# versions like 5.0.24 and 5.0.24a as the same).
 49server_version_re = _lazy_re_compile(r"(\d{1,2})\.(\d{1,2})\.(\d{1,2})")
 50
 51
 52class CursorWrapper:
 53    """
 54    A thin wrapper around MySQLdb's normal cursor class that catches particular
 55    exception instances and reraises them with the correct types.
 56
 57    Implemented as a wrapper, rather than a subclass, so that it isn't stuck
 58    to the particular underlying representation returned by Connection.cursor().
 59    """
 60
 61    codes_for_integrityerror = (
 62        1048,  # Column cannot be null
 63        1690,  # BIGINT UNSIGNED value is out of range
 64        3819,  # CHECK constraint is violated
 65        4025,  # CHECK constraint failed
 66    )
 67
 68    def __init__(self, cursor: Any) -> None:
 69        self.cursor = cursor
 70
 71    def execute(self, query: str, args: Any = None) -> int:
 72        try:
 73            # args is None means no string interpolation
 74            return self.cursor.execute(query, args)
 75        except Database.OperationalError as e:
 76            # Map some error codes to IntegrityError, since they seem to be
 77            # misclassified and Plain would prefer the more logical place.
 78            if e.args[0] in self.codes_for_integrityerror:
 79                raise IntegrityError(*tuple(e.args))
 80            raise
 81
 82    def executemany(self, query: str, args: Any) -> int:
 83        try:
 84            return self.cursor.executemany(query, args)
 85        except Database.OperationalError as e:
 86            # Map some error codes to IntegrityError, since they seem to be
 87            # misclassified and Plain would prefer the more logical place.
 88            if e.args[0] in self.codes_for_integrityerror:
 89                raise IntegrityError(*tuple(e.args))
 90            raise
 91
 92    def __getattr__(self, attr: str) -> Any:
 93        return getattr(self.cursor, attr)
 94
 95    def __iter__(self) -> Any:
 96        return iter(self.cursor)
 97
 98
 99class MySQLDatabaseWrapper(BaseDatabaseWrapper):
100    # Type checker hints: narrow base class attribute types to backend-specific classes
101    ops: DatabaseOperations
102    features: DatabaseFeatures
103    introspection: DatabaseIntrospection
104    creation: DatabaseCreation
105
106    vendor = "mysql"
107    # This dictionary maps Field objects to their associated MySQL column
108    # types, as strings. Column-type strings can contain format strings; they'll
109    # be interpolated against the values of Field.__dict__ before being output.
110    # If a column type is set to None, it won't be included in the output.
111    data_types = {
112        "PrimaryKeyField": "bigint AUTO_INCREMENT",
113        "BinaryField": "longblob",
114        "BooleanField": "bool",
115        "CharField": "varchar(%(max_length)s)",
116        "DateField": "date",
117        "DateTimeField": "datetime(6)",
118        "DecimalField": "numeric(%(max_digits)s, %(decimal_places)s)",
119        "DurationField": "bigint",
120        "FloatField": "double precision",
121        "IntegerField": "integer",
122        "BigIntegerField": "bigint",
123        "GenericIPAddressField": "char(39)",
124        "JSONField": "json",
125        "PositiveBigIntegerField": "bigint UNSIGNED",
126        "PositiveIntegerField": "integer UNSIGNED",
127        "PositiveSmallIntegerField": "smallint UNSIGNED",
128        "SmallIntegerField": "smallint",
129        "TextField": "longtext",
130        "TimeField": "time(6)",
131        "UUIDField": "char(32)",
132    }
133
134    # For these data types:
135    # - MySQL < 8.0.13 doesn't accept default values and implicitly treats them
136    #   as nullable
137    # - all versions of MySQL and MariaDB don't support full width database
138    #   indexes
139    _limited_data_types = (
140        "tinyblob",
141        "blob",
142        "mediumblob",
143        "longblob",
144        "tinytext",
145        "text",
146        "mediumtext",
147        "longtext",
148        "json",
149    )
150
151    operators = {
152        "exact": "= %s",
153        "iexact": "LIKE %s",
154        "contains": "LIKE BINARY %s",
155        "icontains": "LIKE %s",
156        "gt": "> %s",
157        "gte": ">= %s",
158        "lt": "< %s",
159        "lte": "<= %s",
160        "startswith": "LIKE BINARY %s",
161        "endswith": "LIKE BINARY %s",
162        "istartswith": "LIKE %s",
163        "iendswith": "LIKE %s",
164    }
165
166    # The patterns below are used to generate SQL pattern lookup clauses when
167    # the right-hand side of the lookup isn't a raw string (it might be an expression
168    # or the result of a bilateral transformation).
169    # In those cases, special characters for LIKE operators (e.g. \, *, _) should be
170    # escaped on database side.
171    #
172    # Note: we use str.format() here for readability as '%' is used as a wildcard for
173    # the LIKE operator.
174    pattern_esc = r"REPLACE(REPLACE(REPLACE({}, '\\', '\\\\'), '%%', '\%%'), '_', '\_')"
175    pattern_ops = {
176        "contains": "LIKE BINARY CONCAT('%%', {}, '%%')",
177        "icontains": "LIKE CONCAT('%%', {}, '%%')",
178        "startswith": "LIKE BINARY CONCAT({}, '%%')",
179        "istartswith": "LIKE CONCAT({}, '%%')",
180        "endswith": "LIKE BINARY CONCAT('%%', {})",
181        "iendswith": "LIKE CONCAT('%%', {})",
182    }
183
184    isolation_levels = {
185        "read uncommitted",
186        "read committed",
187        "repeatable read",
188        "serializable",
189    }
190
191    Database = Database
192    SchemaEditorClass = DatabaseSchemaEditor
193    # Classes instantiated in __init__().
194    client_class = DatabaseClient
195    creation_class = DatabaseCreation
196    features_class = DatabaseFeatures
197    introspection_class = DatabaseIntrospection
198    ops_class = DatabaseOperations
199    validation_class = DatabaseValidation
200
201    def get_database_version(self) -> tuple[int, int, int]:
202        return self.mysql_version
203
204    def get_connection_params(self) -> dict[str, Any]:
205        kwargs: dict[str, Any] = {
206            "conv": plain_conversions,
207            "charset": "utf8",
208        }
209        settings_dict = self.settings_dict
210        if settings_dict["USER"]:
211            kwargs["user"] = settings_dict["USER"]
212        if settings_dict["NAME"]:
213            kwargs["database"] = settings_dict["NAME"]
214        if settings_dict["PASSWORD"]:
215            kwargs["password"] = settings_dict["PASSWORD"]
216        if settings_dict["HOST"].startswith("/"):
217            kwargs["unix_socket"] = settings_dict["HOST"]
218        elif settings_dict["HOST"]:
219            kwargs["host"] = settings_dict["HOST"]
220        if settings_dict["PORT"]:
221            kwargs["port"] = int(settings_dict["PORT"])
222        # We need the number of potentially affected rows after an
223        # "UPDATE", not the number of changed rows.
224        kwargs["client_flag"] = CLIENT.FOUND_ROWS
225        # Validate the transaction isolation level, if specified.
226        options = settings_dict["OPTIONS"].copy()
227        isolation_level = options.pop("isolation_level", "read committed")
228        if isolation_level:
229            isolation_level = isolation_level.lower()
230            if isolation_level not in self.isolation_levels:
231                raise ImproperlyConfigured(
232                    "Invalid transaction isolation level '{}' specified.\n"
233                    "Use one of {}, or None.".format(
234                        isolation_level,
235                        ", ".join(f"'{s}'" for s in sorted(self.isolation_levels)),
236                    )
237                )
238        self.isolation_level = isolation_level
239        kwargs.update(options)
240        return kwargs
241
242    def get_new_connection(self, conn_params: dict[str, Any]) -> MySQLConnection:
243        connection = Database.connect(**conn_params)
244        # bytes encoder in mysqlclient doesn't work and was added only to
245        # prevent KeyErrors in Plain < 2.0. We can remove this workaround when
246        # mysqlclient 2.1 becomes the minimal mysqlclient supported by Plain.
247        # See https://github.com/PyMySQL/mysqlclient/issues/489
248        if connection.encoders.get(bytes) is bytes:
249            connection.encoders.pop(bytes)
250        return connection
251
252    def init_connection_state(self) -> None:
253        super().init_connection_state()
254        assignments = []
255        if self.features.is_sql_auto_is_null_enabled:
256            # SQL_AUTO_IS_NULL controls whether an AUTO_INCREMENT column on
257            # a recently inserted row will return when the field is tested
258            # for NULL. Disabling this brings this aspect of MySQL in line
259            # with SQL standards.
260            assignments.append("SET SQL_AUTO_IS_NULL = 0")
261
262        if self.isolation_level:
263            assignments.append(
264                f"SET SESSION TRANSACTION ISOLATION LEVEL {self.isolation_level.upper()}"
265            )
266
267        if assignments:
268            with self.cursor() as cursor:
269                cursor.execute("; ".join(assignments))
270
271    def create_cursor(self, name: str | None = None) -> CursorWrapper:
272        cursor = self.connection.cursor()
273        return CursorWrapper(cursor)
274
275    def _rollback(self) -> None:
276        try:
277            BaseDatabaseWrapper._rollback(self)
278        except Database.NotSupportedError:
279            pass
280
281    def _set_autocommit(self, autocommit: bool) -> None:
282        with self.wrap_database_errors:
283            self.connection.autocommit(autocommit)
284
285    def check_constraints(self, table_names: list[str] | None = None) -> None:
286        """Check ``table_names`` for rows with invalid foreign key references."""
287        with self.cursor() as cursor:
288            if table_names is None:
289                table_names = self.introspection.table_names(cursor)
290            for table_name in table_names:
291                primary_key_column_name = self.introspection.get_primary_key_column(
292                    cursor, table_name
293                )
294                if not primary_key_column_name:
295                    continue
296                relations = self.introspection.get_relations(cursor, table_name)
297                for column_name, (
298                    referenced_column_name,
299                    referenced_table_name,
300                ) in relations.items():
301                    cursor.execute(
302                        f"""
303                        SELECT REFERRING.`{primary_key_column_name}`, REFERRING.`{column_name}` FROM `{table_name}` as REFERRING
304                        LEFT JOIN `{referenced_table_name}` as REFERRED
305                        ON (REFERRING.`{column_name}` = REFERRED.`{referenced_column_name}`)
306                        WHERE REFERRING.`{column_name}` IS NOT NULL AND REFERRED.`{referenced_column_name}` IS NULL
307                        """
308                    )
309                    for bad_row in cursor.fetchall():
310                        raise IntegrityError(
311                            f"The row in table '{table_name}' with primary key '{bad_row[0]}' has an "
312                            f"invalid foreign key: {table_name}.{column_name} contains a value '{bad_row[1]}' that "
313                            f"does not have a corresponding value in {referenced_table_name}.{referenced_column_name}."
314                        )
315
316    def is_usable(self) -> bool:
317        try:
318            self.connection.ping()
319        except Database.Error:
320            return False
321        else:
322            return True
323
324    @cached_property
325    def display_name(self) -> str:
326        return "MariaDB" if self.mysql_is_mariadb else "MySQL"
327
328    @cached_property
329    def data_type_check_constraints(self) -> dict[str, str]:
330        if self.features.supports_column_check_constraints:
331            check_constraints = {
332                "PositiveBigIntegerField": "`%(column)s` >= 0",
333                "PositiveIntegerField": "`%(column)s` >= 0",
334                "PositiveSmallIntegerField": "`%(column)s` >= 0",
335            }
336            if self.mysql_is_mariadb and self.mysql_version < (10, 4, 3):
337                # MariaDB < 10.4.3 doesn't automatically use the JSON_VALID as
338                # a check constraint.
339                check_constraints["JSONField"] = "JSON_VALID(`%(column)s`)"
340            return check_constraints
341        return {}
342
343    @cached_property
344    def mysql_server_data(self) -> dict[str, Any]:
345        with self.temporary_connection() as cursor:
346            # Select some server variables and test if the time zone
347            # definitions are installed. CONVERT_TZ returns NULL if 'UTC'
348            # timezone isn't loaded into the mysql.time_zone table.
349            cursor.execute(
350                """
351                SELECT VERSION(),
352                       @@sql_mode,
353                       @@default_storage_engine,
354                       @@sql_auto_is_null,
355                       @@lower_case_table_names,
356                       CONVERT_TZ('2001-01-01 01:00:00', 'UTC', 'UTC') IS NOT NULL
357            """
358            )
359            row = cursor.fetchone()
360        return {
361            "version": row[0],
362            "sql_mode": row[1],
363            "default_storage_engine": row[2],
364            "sql_auto_is_null": bool(row[3]),
365            "lower_case_table_names": bool(row[4]),
366            "has_zoneinfo_database": bool(row[5]),
367        }
368
369    @cached_property
370    def mysql_server_info(self) -> str:
371        return self.mysql_server_data["version"]
372
373    @cached_property
374    def mysql_version(self) -> tuple[int, int, int]:
375        match = server_version_re.match(self.mysql_server_info)
376        if not match:
377            raise Exception(
378                f"Unable to determine MySQL version from version string {self.mysql_server_info!r}"
379            )
380        return cast(tuple[int, int, int], tuple(int(x) for x in match.groups()))
381
382    @cached_property
383    def mysql_is_mariadb(self) -> bool:
384        return "mariadb" in self.mysql_server_info.lower()
385
386    @cached_property
387    def sql_mode(self) -> set[str]:
388        sql_mode = self.mysql_server_data["sql_mode"]
389        return set(sql_mode.split(",") if sql_mode else ())