Plain is headed towards 1.0! Subscribe for development updates →

  1"""
  2MySQL database backend for Plain.
  3
  4Requires mysqlclient: https://pypi.org/project/mysqlclient/
  5"""
  6
  7from __future__ import annotations
  8
  9from functools import cached_property
 10from typing import TYPE_CHECKING, Any, cast
 11
 12import MySQLdb as Database
 13from MySQLdb.constants import CLIENT, FIELD_TYPE
 14from MySQLdb.converters import conversions
 15
 16from plain.exceptions import ImproperlyConfigured
 17from plain.models.backends import utils as backend_utils
 18from plain.models.backends.base.base import BaseDatabaseWrapper
 19from plain.models.db import IntegrityError
 20from plain.utils.regex_helper import _lazy_re_compile
 21
 22# Type checkers always see the proper type; runtime falls back to Any if needed
 23if TYPE_CHECKING:
 24    from MySQLdb.connections import Connection as MySQLConnection
 25else:
 26    try:
 27        from MySQLdb.connections import Connection as MySQLConnection
 28    except ImportError:
 29        MySQLConnection = Any
 30
 31from .client import DatabaseClient
 32from .creation import DatabaseCreation
 33from .features import DatabaseFeatures
 34from .introspection import DatabaseIntrospection
 35from .operations import DatabaseOperations
 36from .schema import DatabaseSchemaEditor
 37from .validation import DatabaseValidation
 38
 39# MySQLdb returns TIME columns as timedelta -- they are more like timedelta in
 40# terms of actual behavior as they are signed and include days -- and Plain
 41# expects time.
 42plain_conversions = {
 43    **conversions,
 44    **{FIELD_TYPE.TIME: backend_utils.typecast_time},
 45}
 46
 47# This should match the numerical portion of the version numbers (we can treat
 48# versions like 5.0.24 and 5.0.24a as the same).
 49server_version_re = _lazy_re_compile(r"(\d{1,2})\.(\d{1,2})\.(\d{1,2})")
 50
 51
 52class CursorWrapper:
 53    """
 54    A thin wrapper around MySQLdb's normal cursor class that catches particular
 55    exception instances and reraises them with the correct types.
 56
 57    Implemented as a wrapper, rather than a subclass, so that it isn't stuck
 58    to the particular underlying representation returned by Connection.cursor().
 59    """
 60
 61    codes_for_integrityerror = (
 62        1048,  # Column cannot be null
 63        1690,  # BIGINT UNSIGNED value is out of range
 64        3819,  # CHECK constraint is violated
 65        4025,  # CHECK constraint failed
 66    )
 67
 68    def __init__(self, cursor: Any) -> None:
 69        self.cursor = cursor
 70
 71    def execute(self, query: str, args: Any = None) -> int:
 72        try:
 73            # args is None means no string interpolation
 74            return self.cursor.execute(query, args)
 75        except Database.OperationalError as e:
 76            # Map some error codes to IntegrityError, since they seem to be
 77            # misclassified and Plain would prefer the more logical place.
 78            if e.args[0] in self.codes_for_integrityerror:
 79                raise IntegrityError(*tuple(e.args))
 80            raise
 81
 82    def executemany(self, query: str, args: Any) -> int:
 83        try:
 84            return self.cursor.executemany(query, args)
 85        except Database.OperationalError as e:
 86            # Map some error codes to IntegrityError, since they seem to be
 87            # misclassified and Plain would prefer the more logical place.
 88            if e.args[0] in self.codes_for_integrityerror:
 89                raise IntegrityError(*tuple(e.args))
 90            raise
 91
 92    def __getattr__(self, attr: str) -> Any:
 93        return getattr(self.cursor, attr)
 94
 95    def __iter__(self) -> Any:
 96        return iter(self.cursor)
 97
 98
class MySQLDatabaseWrapper(BaseDatabaseWrapper):
    """Database wrapper for the "mysql" vendor, built on the MySQLdb driver.

    The class-level tables below describe column types, lookup operators and
    the helper classes used by this backend.
    """

    # Type checker hints: narrow base class attribute types to backend-specific classes
    ops: DatabaseOperations
    features: DatabaseFeatures
    introspection: DatabaseIntrospection
    creation: DatabaseCreation

    vendor = "mysql"
    # This dictionary maps Field objects to their associated MySQL column
    # types, as strings. Column-type strings can contain format strings; they'll
    # be interpolated against the values of Field.__dict__ before being output.
    # If a column type is set to None, it won't be included in the output.
    data_types = {
        "PrimaryKeyField": "bigint AUTO_INCREMENT",
        "BinaryField": "longblob",
        "BooleanField": "bool",
        "CharField": "varchar(%(max_length)s)",
        "DateField": "date",
        "DateTimeField": "datetime(6)",
        "DecimalField": "numeric(%(max_digits)s, %(decimal_places)s)",
        "DurationField": "bigint",
        "FloatField": "double precision",
        "IntegerField": "integer",
        "BigIntegerField": "bigint",
        "GenericIPAddressField": "char(39)",
        "JSONField": "json",
        "PositiveBigIntegerField": "bigint UNSIGNED",
        "PositiveIntegerField": "integer UNSIGNED",
        "PositiveSmallIntegerField": "smallint UNSIGNED",
        "SmallIntegerField": "smallint",
        "TextField": "longtext",
        "TimeField": "time(6)",
        "UUIDField": "char(32)",
    }

    # For these data types:
    # - MySQL < 8.0.13 doesn't accept default values and implicitly treats them
    #   as nullable
    # - all versions of MySQL and MariaDB don't support full width database
    #   indexes
    _limited_data_types = (
        "tinyblob",
        "blob",
        "mediumblob",
        "longblob",
        "tinytext",
        "text",
        "mediumtext",
        "longtext",
        "json",
    )

    # SQL fragment used for each lookup type; %s is the placeholder for the
    # right-hand side. The case-sensitive variants use LIKE BINARY, the
    # case-insensitive ("i"-prefixed) ones plain LIKE.
    operators = {
        "exact": "= %s",
        "iexact": "LIKE %s",
        "contains": "LIKE BINARY %s",
        "icontains": "LIKE %s",
        "gt": "> %s",
        "gte": ">= %s",
        "lt": "< %s",
        "lte": "<= %s",
        "startswith": "LIKE BINARY %s",
        "endswith": "LIKE BINARY %s",
        "istartswith": "LIKE %s",
        "iendswith": "LIKE %s",
    }

    # The patterns below are used to generate SQL pattern lookup clauses when
    # the right-hand side of the lookup isn't a raw string (it might be an expression
    # or the result of a bilateral transformation).
    # In those cases, special characters for LIKE operators (i.e. \, %, _) should be
    # escaped on database side.
    #
    # Note: we use str.format() here for readability as '%' is used as a wildcard for
    # the LIKE operator.
    pattern_esc = r"REPLACE(REPLACE(REPLACE({}, '\\', '\\\\'), '%%', '\%%'), '_', '\_')"
    pattern_ops = {
        "contains": "LIKE BINARY CONCAT('%%', {}, '%%')",
        "icontains": "LIKE CONCAT('%%', {}, '%%')",
        "startswith": "LIKE BINARY CONCAT({}, '%%')",
        "istartswith": "LIKE CONCAT({}, '%%')",
        "endswith": "LIKE BINARY CONCAT('%%', {})",
        "iendswith": "LIKE CONCAT('%%', {})",
    }

    # Accepted (lower-case) values for the "isolation_level" connection option.
    isolation_levels = {
        "read uncommitted",
        "read committed",
        "repeatable read",
        "serializable",
    }

    Database = Database
    SchemaEditorClass = DatabaseSchemaEditor
    # Classes instantiated in __init__().
    client_class = DatabaseClient
    creation_class = DatabaseCreation
    features_class = DatabaseFeatures
    introspection_class = DatabaseIntrospection
    ops_class = DatabaseOperations
    validation_class = DatabaseValidation
200
201    def get_database_version(self) -> tuple[int, int, int]:
202        return self.mysql_version
203
204    def get_connection_params(self) -> dict[str, Any]:
205        kwargs: dict[str, Any] = {
206            "conv": plain_conversions,
207            "charset": "utf8",
208        }
209        settings_dict = self.settings_dict
210        if settings_dict["USER"]:
211            kwargs["user"] = settings_dict["USER"]
212        if settings_dict["NAME"]:
213            kwargs["database"] = settings_dict["NAME"]
214        if settings_dict["PASSWORD"]:
215            kwargs["password"] = settings_dict["PASSWORD"]
216        if settings_dict["HOST"].startswith("/"):
217            kwargs["unix_socket"] = settings_dict["HOST"]
218        elif settings_dict["HOST"]:
219            kwargs["host"] = settings_dict["HOST"]
220        if settings_dict["PORT"]:
221            kwargs["port"] = int(settings_dict["PORT"])
222        # We need the number of potentially affected rows after an
223        # "UPDATE", not the number of changed rows.
224        kwargs["client_flag"] = CLIENT.FOUND_ROWS
225        # Validate the transaction isolation level, if specified.
226        options = settings_dict.get("OPTIONS", {}).copy()
227        isolation_level = options.pop("isolation_level", "read committed")
228        if isolation_level:
229            isolation_level = isolation_level.lower()
230            if isolation_level not in self.isolation_levels:
231                raise ImproperlyConfigured(
232                    "Invalid transaction isolation level '{}' specified.\n"
233                    "Use one of {}, or None.".format(
234                        isolation_level,
235                        ", ".join(f"'{s}'" for s in sorted(self.isolation_levels)),
236                    )
237                )
238        self.isolation_level = isolation_level
239        kwargs.update(options)
240        return kwargs
241
242    def get_new_connection(self, conn_params: dict[str, Any]) -> MySQLConnection:
243        connection = Database.connect(**conn_params)
244        # bytes encoder in mysqlclient doesn't work and was added only to
245        # prevent KeyErrors in Plain < 2.0. We can remove this workaround when
246        # mysqlclient 2.1 becomes the minimal mysqlclient supported by Plain.
247        # See https://github.com/PyMySQL/mysqlclient/issues/489
248        if connection.encoders.get(bytes) is bytes:
249            connection.encoders.pop(bytes)
250        return connection
251
252    def init_connection_state(self) -> None:
253        super().init_connection_state()
254        assignments = []
255        if self.features.is_sql_auto_is_null_enabled:
256            # SQL_AUTO_IS_NULL controls whether an AUTO_INCREMENT column on
257            # a recently inserted row will return when the field is tested
258            # for NULL. Disabling this brings this aspect of MySQL in line
259            # with SQL standards.
260            assignments.append("SET SQL_AUTO_IS_NULL = 0")
261
262        if self.isolation_level:
263            assignments.append(
264                f"SET SESSION TRANSACTION ISOLATION LEVEL {self.isolation_level.upper()}"
265            )
266
267        if assignments:
268            with self.cursor() as cursor:
269                cursor.execute("; ".join(assignments))
270
271    def create_cursor(self, name: str | None = None) -> CursorWrapper:
272        cursor = self.connection.cursor()
273        return CursorWrapper(cursor)
274
275    def _rollback(self) -> None:
276        try:
277            BaseDatabaseWrapper._rollback(self)
278        except Database.NotSupportedError:
279            pass
280
281    def _set_autocommit(self, autocommit: bool) -> None:
282        with self.wrap_database_errors:
283            self.connection.autocommit(autocommit)
284
285    def check_constraints(self, table_names: list[str] | None = None) -> None:
286        """Check ``table_names`` for rows with invalid foreign key references."""
287        with self.cursor() as cursor:
288            if table_names is None:
289                table_names = self.introspection.table_names(cursor)
290            for table_name in table_names:
291                primary_key_column_name = self.introspection.get_primary_key_column(
292                    cursor, table_name
293                )
294                if not primary_key_column_name:
295                    continue
296                relations = self.introspection.get_relations(cursor, table_name)
297                for column_name, (
298                    referenced_column_name,
299                    referenced_table_name,
300                ) in relations.items():
301                    cursor.execute(
302                        f"""
303                        SELECT REFERRING.`{primary_key_column_name}`, REFERRING.`{column_name}` FROM `{table_name}` as REFERRING
304                        LEFT JOIN `{referenced_table_name}` as REFERRED
305                        ON (REFERRING.`{column_name}` = REFERRED.`{referenced_column_name}`)
306                        WHERE REFERRING.`{column_name}` IS NOT NULL AND REFERRED.`{referenced_column_name}` IS NULL
307                        """
308                    )
309                    for bad_row in cursor.fetchall():
310                        raise IntegrityError(
311                            f"The row in table '{table_name}' with primary key '{bad_row[0]}' has an "
312                            f"invalid foreign key: {table_name}.{column_name} contains a value '{bad_row[1]}' that "
313                            f"does not have a corresponding value in {referenced_table_name}.{referenced_column_name}."
314                        )
315
316    def is_usable(self) -> bool:
317        try:
318            self.connection.ping()
319        except Database.Error:
320            return False
321        else:
322            return True
323
324    @cached_property
325    def display_name(self) -> str:
326        return "MariaDB" if self.mysql_is_mariadb else "MySQL"
327
328    @cached_property
329    def data_type_check_constraints(self) -> dict[str, str]:
330        if self.features.supports_column_check_constraints:
331            check_constraints = {
332                "PositiveBigIntegerField": "`%(column)s` >= 0",
333                "PositiveIntegerField": "`%(column)s` >= 0",
334                "PositiveSmallIntegerField": "`%(column)s` >= 0",
335            }
336            if self.mysql_is_mariadb and self.mysql_version < (10, 4, 3):
337                # MariaDB < 10.4.3 doesn't automatically use the JSON_VALID as
338                # a check constraint.
339                check_constraints["JSONField"] = "JSON_VALID(`%(column)s`)"
340            return check_constraints
341        return {}
342
343    @cached_property
344    def mysql_server_data(self) -> dict[str, Any]:
345        with self.temporary_connection() as cursor:
346            # Select some server variables and test if the time zone
347            # definitions are installed. CONVERT_TZ returns NULL if 'UTC'
348            # timezone isn't loaded into the mysql.time_zone table.
349            cursor.execute(
350                """
351                SELECT VERSION(),
352                       @@sql_mode,
353                       @@default_storage_engine,
354                       @@sql_auto_is_null,
355                       @@lower_case_table_names,
356                       CONVERT_TZ('2001-01-01 01:00:00', 'UTC', 'UTC') IS NOT NULL
357            """
358            )
359            row = cursor.fetchone()
360        return {
361            "version": row[0],
362            "sql_mode": row[1],
363            "default_storage_engine": row[2],
364            "sql_auto_is_null": bool(row[3]),
365            "lower_case_table_names": bool(row[4]),
366            "has_zoneinfo_database": bool(row[5]),
367        }
368
369    @cached_property
370    def mysql_server_info(self) -> str:
371        return self.mysql_server_data["version"]
372
373    @cached_property
374    def mysql_version(self) -> tuple[int, int, int]:
375        match = server_version_re.match(self.mysql_server_info)
376        if not match:
377            raise Exception(
378                f"Unable to determine MySQL version from version string {self.mysql_server_info!r}"
379            )
380        return cast(tuple[int, int, int], tuple(int(x) for x in match.groups()))
381
382    @cached_property
383    def mysql_is_mariadb(self) -> bool:
384        return "mariadb" in self.mysql_server_info.lower()
385
386    @cached_property
387    def sql_mode(self) -> set[str]:
388        sql_mode = self.mysql_server_data["sql_mode"]
389        return set(sql_mode.split(",") if sql_mode else ())