Plain is headed towards 1.0! Subscribe for development updates →

  1"""
  2MySQL database backend for Plain.
  3
  4Requires mysqlclient: https://pypi.org/project/mysqlclient/
  5"""
  6
  7from functools import cached_property
  8
  9import MySQLdb as Database
 10from MySQLdb.constants import CLIENT, FIELD_TYPE
 11from MySQLdb.converters import conversions
 12
 13from plain.exceptions import ImproperlyConfigured
 14from plain.models.backends import utils as backend_utils
 15from plain.models.backends.base.base import BaseDatabaseWrapper
 16from plain.models.db import IntegrityError
 17from plain.utils.regex_helper import _lazy_re_compile
 18
 19from .client import DatabaseClient
 20from .creation import DatabaseCreation
 21from .features import DatabaseFeatures
 22from .introspection import DatabaseIntrospection
 23from .operations import DatabaseOperations
 24from .schema import DatabaseSchemaEditor
 25from .validation import DatabaseValidation
 26
 27# MySQLdb returns TIME columns as timedelta -- they are more like timedelta in
 28# terms of actual behavior as they are signed and include days -- and Plain
 29# expects time.
 30plain_conversions = {
 31    **conversions,
 32    **{FIELD_TYPE.TIME: backend_utils.typecast_time},
 33}
 34
 35# This should match the numerical portion of the version numbers (we can treat
 36# versions like 5.0.24 and 5.0.24a as the same).
 37server_version_re = _lazy_re_compile(r"(\d{1,2})\.(\d{1,2})\.(\d{1,2})")
 38
 39
 40class CursorWrapper:
 41    """
 42    A thin wrapper around MySQLdb's normal cursor class that catches particular
 43    exception instances and reraises them with the correct types.
 44
 45    Implemented as a wrapper, rather than a subclass, so that it isn't stuck
 46    to the particular underlying representation returned by Connection.cursor().
 47    """
 48
 49    codes_for_integrityerror = (
 50        1048,  # Column cannot be null
 51        1690,  # BIGINT UNSIGNED value is out of range
 52        3819,  # CHECK constraint is violated
 53        4025,  # CHECK constraint failed
 54    )
 55
 56    def __init__(self, cursor):
 57        self.cursor = cursor
 58
 59    def execute(self, query, args=None):
 60        try:
 61            # args is None means no string interpolation
 62            return self.cursor.execute(query, args)
 63        except Database.OperationalError as e:
 64            # Map some error codes to IntegrityError, since they seem to be
 65            # misclassified and Plain would prefer the more logical place.
 66            if e.args[0] in self.codes_for_integrityerror:
 67                raise IntegrityError(*tuple(e.args))
 68            raise
 69
 70    def executemany(self, query, args):
 71        try:
 72            return self.cursor.executemany(query, args)
 73        except Database.OperationalError as e:
 74            # Map some error codes to IntegrityError, since they seem to be
 75            # misclassified and Plain would prefer the more logical place.
 76            if e.args[0] in self.codes_for_integrityerror:
 77                raise IntegrityError(*tuple(e.args))
 78            raise
 79
 80    def __getattr__(self, attr):
 81        return getattr(self.cursor, attr)
 82
 83    def __iter__(self):
 84        return iter(self.cursor)
 85
 86
class DatabaseWrapper(BaseDatabaseWrapper):
    """
    Database wrapper for the MySQL backend, built on the MySQLdb driver.

    The class-level tables below declare the column types, lookup operators
    and LIKE patterns used for this backend, plus the helper classes that
    implement the backend's other responsibilities.
    """

    vendor = "mysql"
    # This dictionary maps Field objects to their associated MySQL column
    # types, as strings. Column-type strings can contain format strings; they'll
    # be interpolated against the values of Field.__dict__ before being output.
    # If a column type is set to None, it won't be included in the output.
    data_types = {
        "AutoField": "integer AUTO_INCREMENT",
        "BigAutoField": "bigint AUTO_INCREMENT",
        "BinaryField": "longblob",
        "BooleanField": "bool",
        "CharField": "varchar(%(max_length)s)",
        "DateField": "date",
        "DateTimeField": "datetime(6)",
        "DecimalField": "numeric(%(max_digits)s, %(decimal_places)s)",
        "DurationField": "bigint",
        "FloatField": "double precision",
        "IntegerField": "integer",
        "BigIntegerField": "bigint",
        "IPAddressField": "char(15)",
        "GenericIPAddressField": "char(39)",
        "JSONField": "json",
        "PositiveBigIntegerField": "bigint UNSIGNED",
        "PositiveIntegerField": "integer UNSIGNED",
        "PositiveSmallIntegerField": "smallint UNSIGNED",
        "SmallAutoField": "smallint AUTO_INCREMENT",
        "SmallIntegerField": "smallint",
        "TextField": "longtext",
        "TimeField": "time(6)",
        "UUIDField": "char(32)",
    }

    # For these data types:
    # - MySQL < 8.0.13 doesn't accept default values and implicitly treats them
    #   as nullable
    # - all versions of MySQL and MariaDB don't support full width database
    #   indexes
    _limited_data_types = (
        "tinyblob",
        "blob",
        "mediumblob",
        "longblob",
        "tinytext",
        "text",
        "mediumtext",
        "longtext",
        "json",
    )

    # SQL fragments for each lookup type; %s is the placeholder for the
    # right-hand side. The case-sensitive lookups use LIKE BINARY, the
    # case-insensitive ("i"-prefixed) ones use plain LIKE.
    operators = {
        "exact": "= %s",
        "iexact": "LIKE %s",
        "contains": "LIKE BINARY %s",
        "icontains": "LIKE %s",
        "gt": "> %s",
        "gte": ">= %s",
        "lt": "< %s",
        "lte": "<= %s",
        "startswith": "LIKE BINARY %s",
        "endswith": "LIKE BINARY %s",
        "istartswith": "LIKE %s",
        "iendswith": "LIKE %s",
    }

    # The patterns below are used to generate SQL pattern lookup clauses when
    # the right-hand side of the lookup isn't a raw string (it might be an expression
    # or the result of a bilateral transformation).
    # In those cases, special characters for LIKE operators (e.g. \, %, _) should be
    # escaped on database side.
    #
    # Note: we use str.format() here for readability as '%' is used as a wildcard for
    # the LIKE operator.
    pattern_esc = r"REPLACE(REPLACE(REPLACE({}, '\\', '\\\\'), '%%', '\%%'), '_', '\_')"
    pattern_ops = {
        "contains": "LIKE BINARY CONCAT('%%', {}, '%%')",
        "icontains": "LIKE CONCAT('%%', {}, '%%')",
        "startswith": "LIKE BINARY CONCAT({}, '%%')",
        "istartswith": "LIKE CONCAT({}, '%%')",
        "endswith": "LIKE BINARY CONCAT('%%', {})",
        "iendswith": "LIKE CONCAT('%%', {})",
    }

    # Lower-case names of the transaction isolation levels accepted for the
    # "isolation_level" connection option.
    isolation_levels = {
        "read uncommitted",
        "read committed",
        "repeatable read",
        "serializable",
    }

    # The driver module, exposed on the wrapper class.
    Database = Database
    SchemaEditorClass = DatabaseSchemaEditor
    # Classes instantiated in __init__().
    client_class = DatabaseClient
    creation_class = DatabaseCreation
    features_class = DatabaseFeatures
    introspection_class = DatabaseIntrospection
    ops_class = DatabaseOperations
    validation_class = DatabaseValidation
185
186    def get_database_version(self):
187        return self.mysql_version
188
189    def get_connection_params(self):
190        kwargs = {
191            "conv": plain_conversions,
192            "charset": "utf8",
193        }
194        settings_dict = self.settings_dict
195        if settings_dict["USER"]:
196            kwargs["user"] = settings_dict["USER"]
197        if settings_dict["NAME"]:
198            kwargs["database"] = settings_dict["NAME"]
199        if settings_dict["PASSWORD"]:
200            kwargs["password"] = settings_dict["PASSWORD"]
201        if settings_dict["HOST"].startswith("/"):
202            kwargs["unix_socket"] = settings_dict["HOST"]
203        elif settings_dict["HOST"]:
204            kwargs["host"] = settings_dict["HOST"]
205        if settings_dict["PORT"]:
206            kwargs["port"] = int(settings_dict["PORT"])
207        # We need the number of potentially affected rows after an
208        # "UPDATE", not the number of changed rows.
209        kwargs["client_flag"] = CLIENT.FOUND_ROWS
210        # Validate the transaction isolation level, if specified.
211        options = settings_dict["OPTIONS"].copy()
212        isolation_level = options.pop("isolation_level", "read committed")
213        if isolation_level:
214            isolation_level = isolation_level.lower()
215            if isolation_level not in self.isolation_levels:
216                raise ImproperlyConfigured(
217                    "Invalid transaction isolation level '{}' specified.\n"
218                    "Use one of {}, or None.".format(
219                        isolation_level,
220                        ", ".join(f"'{s}'" for s in sorted(self.isolation_levels)),
221                    )
222                )
223        self.isolation_level = isolation_level
224        kwargs.update(options)
225        return kwargs
226
227    def get_new_connection(self, conn_params):
228        connection = Database.connect(**conn_params)
229        # bytes encoder in mysqlclient doesn't work and was added only to
230        # prevent KeyErrors in Plain < 2.0. We can remove this workaround when
231        # mysqlclient 2.1 becomes the minimal mysqlclient supported by Plain.
232        # See https://github.com/PyMySQL/mysqlclient/issues/489
233        if connection.encoders.get(bytes) is bytes:
234            connection.encoders.pop(bytes)
235        return connection
236
237    def init_connection_state(self):
238        super().init_connection_state()
239        assignments = []
240        if self.features.is_sql_auto_is_null_enabled:
241            # SQL_AUTO_IS_NULL controls whether an AUTO_INCREMENT column on
242            # a recently inserted row will return when the field is tested
243            # for NULL. Disabling this brings this aspect of MySQL in line
244            # with SQL standards.
245            assignments.append("SET SQL_AUTO_IS_NULL = 0")
246
247        if self.isolation_level:
248            assignments.append(
249                f"SET SESSION TRANSACTION ISOLATION LEVEL {self.isolation_level.upper()}"
250            )
251
252        if assignments:
253            with self.cursor() as cursor:
254                cursor.execute("; ".join(assignments))
255
256    def create_cursor(self, name=None):
257        cursor = self.connection.cursor()
258        return CursorWrapper(cursor)
259
260    def _rollback(self):
261        try:
262            BaseDatabaseWrapper._rollback(self)
263        except Database.NotSupportedError:
264            pass
265
266    def _set_autocommit(self, autocommit):
267        with self.wrap_database_errors:
268            self.connection.autocommit(autocommit)
269
270    def check_constraints(self, table_names=None):
271        """Check ``table_names`` for rows with invalid foreign key references."""
272        with self.cursor() as cursor:
273            if table_names is None:
274                table_names = self.introspection.table_names(cursor)
275            for table_name in table_names:
276                primary_key_column_name = self.introspection.get_primary_key_column(
277                    cursor, table_name
278                )
279                if not primary_key_column_name:
280                    continue
281                relations = self.introspection.get_relations(cursor, table_name)
282                for column_name, (
283                    referenced_column_name,
284                    referenced_table_name,
285                ) in relations.items():
286                    cursor.execute(
287                        f"""
288                        SELECT REFERRING.`{primary_key_column_name}`, REFERRING.`{column_name}` FROM `{table_name}` as REFERRING
289                        LEFT JOIN `{referenced_table_name}` as REFERRED
290                        ON (REFERRING.`{column_name}` = REFERRED.`{referenced_column_name}`)
291                        WHERE REFERRING.`{column_name}` IS NOT NULL AND REFERRED.`{referenced_column_name}` IS NULL
292                        """
293                    )
294                    for bad_row in cursor.fetchall():
295                        raise IntegrityError(
296                            f"The row in table '{table_name}' with primary key '{bad_row[0]}' has an "
297                            f"invalid foreign key: {table_name}.{column_name} contains a value '{bad_row[1]}' that "
298                            f"does not have a corresponding value in {referenced_table_name}.{referenced_column_name}."
299                        )
300
301    def is_usable(self):
302        try:
303            self.connection.ping()
304        except Database.Error:
305            return False
306        else:
307            return True
308
309    @cached_property
310    def display_name(self):
311        return "MariaDB" if self.mysql_is_mariadb else "MySQL"
312
313    @cached_property
314    def data_type_check_constraints(self):
315        if self.features.supports_column_check_constraints:
316            check_constraints = {
317                "PositiveBigIntegerField": "`%(column)s` >= 0",
318                "PositiveIntegerField": "`%(column)s` >= 0",
319                "PositiveSmallIntegerField": "`%(column)s` >= 0",
320            }
321            if self.mysql_is_mariadb and self.mysql_version < (10, 4, 3):
322                # MariaDB < 10.4.3 doesn't automatically use the JSON_VALID as
323                # a check constraint.
324                check_constraints["JSONField"] = "JSON_VALID(`%(column)s`)"
325            return check_constraints
326        return {}
327
328    @cached_property
329    def mysql_server_data(self):
330        with self.temporary_connection() as cursor:
331            # Select some server variables and test if the time zone
332            # definitions are installed. CONVERT_TZ returns NULL if 'UTC'
333            # timezone isn't loaded into the mysql.time_zone table.
334            cursor.execute(
335                """
336                SELECT VERSION(),
337                       @@sql_mode,
338                       @@default_storage_engine,
339                       @@sql_auto_is_null,
340                       @@lower_case_table_names,
341                       CONVERT_TZ('2001-01-01 01:00:00', 'UTC', 'UTC') IS NOT NULL
342            """
343            )
344            row = cursor.fetchone()
345        return {
346            "version": row[0],
347            "sql_mode": row[1],
348            "default_storage_engine": row[2],
349            "sql_auto_is_null": bool(row[3]),
350            "lower_case_table_names": bool(row[4]),
351            "has_zoneinfo_database": bool(row[5]),
352        }
353
354    @cached_property
355    def mysql_server_info(self):
356        return self.mysql_server_data["version"]
357
358    @cached_property
359    def mysql_version(self):
360        match = server_version_re.match(self.mysql_server_info)
361        if not match:
362            raise Exception(
363                f"Unable to determine MySQL version from version string {self.mysql_server_info!r}"
364            )
365        return tuple(int(x) for x in match.groups())
366
367    @cached_property
368    def mysql_is_mariadb(self):
369        return "mariadb" in self.mysql_server_info.lower()
370
371    @cached_property
372    def sql_mode(self):
373        sql_mode = self.mysql_server_data["sql_mode"]
374        return set(sql_mode.split(",") if sql_mode else ())