Plain is headed towards 1.0! Subscribe for development updates →

  1"""
  2MySQL database backend for Plain.
  3
  4Requires mysqlclient: https://pypi.org/project/mysqlclient/
  5"""
  6from plain.exceptions import ImproperlyConfigured
  7from plain.models.backends import utils as backend_utils
  8from plain.models.backends.base.base import BaseDatabaseWrapper
  9from plain.models.db import IntegrityError
 10from plain.utils.functional import cached_property
 11from plain.utils.regex_helper import _lazy_re_compile
 12
 13try:
 14    import MySQLdb as Database
 15except ImportError as err:
 16    raise ImproperlyConfigured(
 17        "Error loading MySQLdb module.\nDid you install mysqlclient?"
 18    ) from err
 19
 20from MySQLdb.constants import CLIENT, FIELD_TYPE
 21from MySQLdb.converters import conversions
 22
 23# Some of these import MySQLdb, so import them after checking if it's installed.
 24from .client import DatabaseClient
 25from .creation import DatabaseCreation
 26from .features import DatabaseFeatures
 27from .introspection import DatabaseIntrospection
 28from .operations import DatabaseOperations
 29from .schema import DatabaseSchemaEditor
 30from .validation import DatabaseValidation
 31
 32version = Database.version_info
 33if version < (1, 4, 3):
 34    raise ImproperlyConfigured(
 35        "mysqlclient 1.4.3 or newer is required; you have %s." % Database.__version__
 36    )
 37
 38
 39# MySQLdb returns TIME columns as timedelta -- they are more like timedelta in
 40# terms of actual behavior as they are signed and include days -- and Plain
 41# expects time.
 42plain_conversions = {
 43    **conversions,
 44    **{FIELD_TYPE.TIME: backend_utils.typecast_time},
 45}
 46
 47# This should match the numerical portion of the version numbers (we can treat
 48# versions like 5.0.24 and 5.0.24a as the same).
 49server_version_re = _lazy_re_compile(r"(\d{1,2})\.(\d{1,2})\.(\d{1,2})")
 50
 51
 52class CursorWrapper:
 53    """
 54    A thin wrapper around MySQLdb's normal cursor class that catches particular
 55    exception instances and reraises them with the correct types.
 56
 57    Implemented as a wrapper, rather than a subclass, so that it isn't stuck
 58    to the particular underlying representation returned by Connection.cursor().
 59    """
 60
 61    codes_for_integrityerror = (
 62        1048,  # Column cannot be null
 63        1690,  # BIGINT UNSIGNED value is out of range
 64        3819,  # CHECK constraint is violated
 65        4025,  # CHECK constraint failed
 66    )
 67
 68    def __init__(self, cursor):
 69        self.cursor = cursor
 70
 71    def execute(self, query, args=None):
 72        try:
 73            # args is None means no string interpolation
 74            return self.cursor.execute(query, args)
 75        except Database.OperationalError as e:
 76            # Map some error codes to IntegrityError, since they seem to be
 77            # misclassified and Plain would prefer the more logical place.
 78            if e.args[0] in self.codes_for_integrityerror:
 79                raise IntegrityError(*tuple(e.args))
 80            raise
 81
 82    def executemany(self, query, args):
 83        try:
 84            return self.cursor.executemany(query, args)
 85        except Database.OperationalError as e:
 86            # Map some error codes to IntegrityError, since they seem to be
 87            # misclassified and Plain would prefer the more logical place.
 88            if e.args[0] in self.codes_for_integrityerror:
 89                raise IntegrityError(*tuple(e.args))
 90            raise
 91
 92    def __getattr__(self, attr):
 93        return getattr(self.cursor, attr)
 94
 95    def __iter__(self):
 96        return iter(self.cursor)
 97
 98
 99class DatabaseWrapper(BaseDatabaseWrapper):
100    vendor = "mysql"
101    # This dictionary maps Field objects to their associated MySQL column
102    # types, as strings. Column-type strings can contain format strings; they'll
103    # be interpolated against the values of Field.__dict__ before being output.
104    # If a column type is set to None, it won't be included in the output.
105    data_types = {
106        "AutoField": "integer AUTO_INCREMENT",
107        "BigAutoField": "bigint AUTO_INCREMENT",
108        "BinaryField": "longblob",
109        "BooleanField": "bool",
110        "CharField": "varchar(%(max_length)s)",
111        "DateField": "date",
112        "DateTimeField": "datetime(6)",
113        "DecimalField": "numeric(%(max_digits)s, %(decimal_places)s)",
114        "DurationField": "bigint",
115        "FloatField": "double precision",
116        "IntegerField": "integer",
117        "BigIntegerField": "bigint",
118        "IPAddressField": "char(15)",
119        "GenericIPAddressField": "char(39)",
120        "JSONField": "json",
121        "OneToOneField": "integer",
122        "PositiveBigIntegerField": "bigint UNSIGNED",
123        "PositiveIntegerField": "integer UNSIGNED",
124        "PositiveSmallIntegerField": "smallint UNSIGNED",
125        "SlugField": "varchar(%(max_length)s)",
126        "SmallAutoField": "smallint AUTO_INCREMENT",
127        "SmallIntegerField": "smallint",
128        "TextField": "longtext",
129        "TimeField": "time(6)",
130        "UUIDField": "char(32)",
131    }
132
133    # For these data types:
134    # - MySQL < 8.0.13 doesn't accept default values and implicitly treats them
135    #   as nullable
136    # - all versions of MySQL and MariaDB don't support full width database
137    #   indexes
138    _limited_data_types = (
139        "tinyblob",
140        "blob",
141        "mediumblob",
142        "longblob",
143        "tinytext",
144        "text",
145        "mediumtext",
146        "longtext",
147        "json",
148    )
149
150    operators = {
151        "exact": "= %s",
152        "iexact": "LIKE %s",
153        "contains": "LIKE BINARY %s",
154        "icontains": "LIKE %s",
155        "gt": "> %s",
156        "gte": ">= %s",
157        "lt": "< %s",
158        "lte": "<= %s",
159        "startswith": "LIKE BINARY %s",
160        "endswith": "LIKE BINARY %s",
161        "istartswith": "LIKE %s",
162        "iendswith": "LIKE %s",
163    }
164
165    # The patterns below are used to generate SQL pattern lookup clauses when
166    # the right-hand side of the lookup isn't a raw string (it might be an expression
167    # or the result of a bilateral transformation).
168    # In those cases, special characters for LIKE operators (e.g. \, *, _) should be
169    # escaped on database side.
170    #
171    # Note: we use str.format() here for readability as '%' is used as a wildcard for
172    # the LIKE operator.
173    pattern_esc = r"REPLACE(REPLACE(REPLACE({}, '\\', '\\\\'), '%%', '\%%'), '_', '\_')"
174    pattern_ops = {
175        "contains": "LIKE BINARY CONCAT('%%', {}, '%%')",
176        "icontains": "LIKE CONCAT('%%', {}, '%%')",
177        "startswith": "LIKE BINARY CONCAT({}, '%%')",
178        "istartswith": "LIKE CONCAT({}, '%%')",
179        "endswith": "LIKE BINARY CONCAT('%%', {})",
180        "iendswith": "LIKE CONCAT('%%', {})",
181    }
182
183    isolation_levels = {
184        "read uncommitted",
185        "read committed",
186        "repeatable read",
187        "serializable",
188    }
189
190    Database = Database
191    SchemaEditorClass = DatabaseSchemaEditor
192    # Classes instantiated in __init__().
193    client_class = DatabaseClient
194    creation_class = DatabaseCreation
195    features_class = DatabaseFeatures
196    introspection_class = DatabaseIntrospection
197    ops_class = DatabaseOperations
198    validation_class = DatabaseValidation
199
200    def get_database_version(self):
201        return self.mysql_version
202
203    def get_connection_params(self):
204        kwargs = {
205            "conv": plain_conversions,
206            "charset": "utf8",
207        }
208        settings_dict = self.settings_dict
209        if settings_dict["USER"]:
210            kwargs["user"] = settings_dict["USER"]
211        if settings_dict["NAME"]:
212            kwargs["database"] = settings_dict["NAME"]
213        if settings_dict["PASSWORD"]:
214            kwargs["password"] = settings_dict["PASSWORD"]
215        if settings_dict["HOST"].startswith("/"):
216            kwargs["unix_socket"] = settings_dict["HOST"]
217        elif settings_dict["HOST"]:
218            kwargs["host"] = settings_dict["HOST"]
219        if settings_dict["PORT"]:
220            kwargs["port"] = int(settings_dict["PORT"])
221        # We need the number of potentially affected rows after an
222        # "UPDATE", not the number of changed rows.
223        kwargs["client_flag"] = CLIENT.FOUND_ROWS
224        # Validate the transaction isolation level, if specified.
225        options = settings_dict["OPTIONS"].copy()
226        isolation_level = options.pop("isolation_level", "read committed")
227        if isolation_level:
228            isolation_level = isolation_level.lower()
229            if isolation_level not in self.isolation_levels:
230                raise ImproperlyConfigured(
231                    "Invalid transaction isolation level '{}' specified.\n"
232                    "Use one of {}, or None.".format(
233                        isolation_level,
234                        ", ".join("'%s'" % s for s in sorted(self.isolation_levels)),
235                    )
236                )
237        self.isolation_level = isolation_level
238        kwargs.update(options)
239        return kwargs
240
241    def get_new_connection(self, conn_params):
242        connection = Database.connect(**conn_params)
243        # bytes encoder in mysqlclient doesn't work and was added only to
244        # prevent KeyErrors in Plain < 2.0. We can remove this workaround when
245        # mysqlclient 2.1 becomes the minimal mysqlclient supported by Plain.
246        # See https://github.com/PyMySQL/mysqlclient/issues/489
247        if connection.encoders.get(bytes) is bytes:
248            connection.encoders.pop(bytes)
249        return connection
250
251    def init_connection_state(self):
252        super().init_connection_state()
253        assignments = []
254        if self.features.is_sql_auto_is_null_enabled:
255            # SQL_AUTO_IS_NULL controls whether an AUTO_INCREMENT column on
256            # a recently inserted row will return when the field is tested
257            # for NULL. Disabling this brings this aspect of MySQL in line
258            # with SQL standards.
259            assignments.append("SET SQL_AUTO_IS_NULL = 0")
260
261        if self.isolation_level:
262            assignments.append(
263                "SET SESSION TRANSACTION ISOLATION LEVEL %s"
264                % self.isolation_level.upper()
265            )
266
267        if assignments:
268            with self.cursor() as cursor:
269                cursor.execute("; ".join(assignments))
270
271    def create_cursor(self, name=None):
272        cursor = self.connection.cursor()
273        return CursorWrapper(cursor)
274
275    def _rollback(self):
276        try:
277            BaseDatabaseWrapper._rollback(self)
278        except Database.NotSupportedError:
279            pass
280
281    def _set_autocommit(self, autocommit):
282        with self.wrap_database_errors:
283            self.connection.autocommit(autocommit)
284
285    def disable_constraint_checking(self):
286        """
287        Disable foreign key checks, primarily for use in adding rows with
288        forward references. Always return True to indicate constraint checks
289        need to be re-enabled.
290        """
291        with self.cursor() as cursor:
292            cursor.execute("SET foreign_key_checks=0")
293        return True
294
295    def enable_constraint_checking(self):
296        """
297        Re-enable foreign key checks after they have been disabled.
298        """
299        # Override needs_rollback in case constraint_checks_disabled is
300        # nested inside transaction.atomic.
301        self.needs_rollback, needs_rollback = False, self.needs_rollback
302        try:
303            with self.cursor() as cursor:
304                cursor.execute("SET foreign_key_checks=1")
305        finally:
306            self.needs_rollback = needs_rollback
307
308    def check_constraints(self, table_names=None):
309        """
310        Check each table name in `table_names` for rows with invalid foreign
311        key references. This method is intended to be used in conjunction with
312        `disable_constraint_checking()` and `enable_constraint_checking()`, to
313        determine if rows with invalid references were entered while constraint
314        checks were off.
315        """
316        with self.cursor() as cursor:
317            if table_names is None:
318                table_names = self.introspection.table_names(cursor)
319            for table_name in table_names:
320                primary_key_column_name = self.introspection.get_primary_key_column(
321                    cursor, table_name
322                )
323                if not primary_key_column_name:
324                    continue
325                relations = self.introspection.get_relations(cursor, table_name)
326                for column_name, (
327                    referenced_column_name,
328                    referenced_table_name,
329                ) in relations.items():
330                    cursor.execute(
331                        """
332                        SELECT REFERRING.`{}`, REFERRING.`{}` FROM `{}` as REFERRING
333                        LEFT JOIN `{}` as REFERRED
334                        ON (REFERRING.`{}` = REFERRED.`{}`)
335                        WHERE REFERRING.`{}` IS NOT NULL AND REFERRED.`{}` IS NULL
336                        """.format(
337                            primary_key_column_name,
338                            column_name,
339                            table_name,
340                            referenced_table_name,
341                            column_name,
342                            referenced_column_name,
343                            column_name,
344                            referenced_column_name,
345                        )
346                    )
347                    for bad_row in cursor.fetchall():
348                        raise IntegrityError(
349                            "The row in table '{}' with primary key '{}' has an "
350                            "invalid foreign key: {}.{} contains a value '{}' that "
351                            "does not have a corresponding value in {}.{}.".format(
352                                table_name,
353                                bad_row[0],
354                                table_name,
355                                column_name,
356                                bad_row[1],
357                                referenced_table_name,
358                                referenced_column_name,
359                            )
360                        )
361
362    def is_usable(self):
363        try:
364            self.connection.ping()
365        except Database.Error:
366            return False
367        else:
368            return True
369
370    @cached_property
371    def display_name(self):
372        return "MariaDB" if self.mysql_is_mariadb else "MySQL"
373
374    @cached_property
375    def data_type_check_constraints(self):
376        if self.features.supports_column_check_constraints:
377            check_constraints = {
378                "PositiveBigIntegerField": "`%(column)s` >= 0",
379                "PositiveIntegerField": "`%(column)s` >= 0",
380                "PositiveSmallIntegerField": "`%(column)s` >= 0",
381            }
382            if self.mysql_is_mariadb and self.mysql_version < (10, 4, 3):
383                # MariaDB < 10.4.3 doesn't automatically use the JSON_VALID as
384                # a check constraint.
385                check_constraints["JSONField"] = "JSON_VALID(`%(column)s`)"
386            return check_constraints
387        return {}
388
389    @cached_property
390    def mysql_server_data(self):
391        with self.temporary_connection() as cursor:
392            # Select some server variables and test if the time zone
393            # definitions are installed. CONVERT_TZ returns NULL if 'UTC'
394            # timezone isn't loaded into the mysql.time_zone table.
395            cursor.execute(
396                """
397                SELECT VERSION(),
398                       @@sql_mode,
399                       @@default_storage_engine,
400                       @@sql_auto_is_null,
401                       @@lower_case_table_names,
402                       CONVERT_TZ('2001-01-01 01:00:00', 'UTC', 'UTC') IS NOT NULL
403            """
404            )
405            row = cursor.fetchone()
406        return {
407            "version": row[0],
408            "sql_mode": row[1],
409            "default_storage_engine": row[2],
410            "sql_auto_is_null": bool(row[3]),
411            "lower_case_table_names": bool(row[4]),
412            "has_zoneinfo_database": bool(row[5]),
413        }
414
415    @cached_property
416    def mysql_server_info(self):
417        return self.mysql_server_data["version"]
418
419    @cached_property
420    def mysql_version(self):
421        match = server_version_re.match(self.mysql_server_info)
422        if not match:
423            raise Exception(
424                "Unable to determine MySQL version from version string %r"
425                % self.mysql_server_info
426            )
427        return tuple(int(x) for x in match.groups())
428
429    @cached_property
430    def mysql_is_mariadb(self):
431        return "mariadb" in self.mysql_server_info.lower()
432
433    @cached_property
434    def sql_mode(self):
435        sql_mode = self.mysql_server_data["sql_mode"]
436        return set(sql_mode.split(",") if sql_mode else ())