from __future__ import annotations

from collections import namedtuple
from typing import Any

import sqlparse
from MySQLdb.constants import FIELD_TYPE  # type: ignore[import-untyped]

from plain.models.backends.base.introspection import BaseDatabaseIntrospection
from plain.models.backends.base.introspection import FieldInfo as BaseFieldInfo
from plain.models.backends.base.introspection import TableInfo as BaseTableInfo
from plain.models.indexes import Index
from plain.utils.datastructures import OrderedSet

FieldInfo = namedtuple(
    "FieldInfo",
    BaseFieldInfo._fields + ("extra", "is_unsigned", "has_json_constraint", "comment"),
)
InfoLine = namedtuple(
    "InfoLine",
    "col_name data_type max_len num_prec num_scale extra column_default "
    "collation is_unsigned comment",
)
TableInfo = namedtuple("TableInfo", BaseTableInfo._fields + ("comment",))


class DatabaseIntrospection(BaseDatabaseIntrospection):
    data_types_reverse = {
        FIELD_TYPE.BLOB: "TextField",
        FIELD_TYPE.CHAR: "CharField",
        FIELD_TYPE.DECIMAL: "DecimalField",
        FIELD_TYPE.NEWDECIMAL: "DecimalField",
        FIELD_TYPE.DATE: "DateField",
        FIELD_TYPE.DATETIME: "DateTimeField",
        FIELD_TYPE.DOUBLE: "FloatField",
        FIELD_TYPE.FLOAT: "FloatField",
        FIELD_TYPE.INT24: "IntegerField",
        FIELD_TYPE.JSON: "JSONField",
        FIELD_TYPE.LONG: "IntegerField",
        FIELD_TYPE.LONGLONG: "BigIntegerField",
        FIELD_TYPE.SHORT: "SmallIntegerField",
        FIELD_TYPE.STRING: "CharField",
        FIELD_TYPE.TIME: "TimeField",
        FIELD_TYPE.TIMESTAMP: "DateTimeField",
        FIELD_TYPE.TINY: "IntegerField",
        FIELD_TYPE.TINY_BLOB: "TextField",
        FIELD_TYPE.MEDIUM_BLOB: "TextField",
        FIELD_TYPE.LONG_BLOB: "TextField",
        FIELD_TYPE.VAR_STRING: "CharField",
    }

    def get_field_type(self, data_type: Any, description: Any) -> str:
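        """
        Map a database column type to a field type, refining the base mapping
        with the column's auto_increment/unsigned metadata and, on MariaDB,
        JSON_VALID() check constraints.
        """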
        field_type = super().get_field_type(data_type, description)
        if "auto_increment" in description.extra:
            if field_type == "BigIntegerField":
                return "PrimaryKeyField"
        if description.is_unsigned:
            if field_type == "BigIntegerField":
                return "PositiveBigIntegerField"
            elif field_type == "IntegerField":
                return "PositiveIntegerField"
            elif field_type == "SmallIntegerField":
                return "PositiveSmallIntegerField"
        # JSON data type is an alias for LONGTEXT in MariaDB, use check
        # constraint clauses to introspect JSONField.
        if description.has_json_constraint:
            return "JSONField"
        return field_type

    def get_table_list(self, cursor: Any) -> list[TableInfo]:
 71        """Return a list of table and view names in the current database."""
 72        cursor.execute(
 73            """
 74            SELECT
 75                table_name,
 76                table_type,
 77                table_comment
 78            FROM information_schema.tables
 79            WHERE table_schema = DATABASE()
 80            """
 81        )
 82        return [
 83            TableInfo(row[0], {"BASE TABLE": "t", "VIEW": "v"}.get(row[1]), row[2])
 84            for row in cursor.fetchall()
 85        ]
 86
 87    def get_table_description(self, cursor: Any, table_name: str) -> list[FieldInfo]:
 88        """
 89        Return a description of the table with the DB-API cursor.description
 90        interface."
 91        """
 92        json_constraints: set[Any] = set()
 93        if (
 94            self.connection.mysql_is_mariadb
 95            and self.connection.features.can_introspect_json_field
 96        ):
 97            # JSON data type is an alias for LONGTEXT in MariaDB, select
 98            # JSON_VALID() constraints to introspect JSONField.
 99            cursor.execute(
100                """
101                SELECT c.constraint_name AS column_name
102                FROM information_schema.check_constraints AS c
103                WHERE
104                    c.table_name = %s AND
105                    LOWER(c.check_clause) =
106                        'json_valid(`' + LOWER(c.constraint_name) + '`)' AND
107                    c.constraint_schema = DATABASE()
108                """,
109                [table_name],
110            )
111            json_constraints = {row[0] for row in cursor.fetchall()}
112        # A default collation for the given table.
113        cursor.execute(
114            """
115            SELECT  table_collation
116            FROM    information_schema.tables
117            WHERE   table_schema = DATABASE()
118            AND     table_name = %s
119            """,
120            [table_name],
121        )
122        row = cursor.fetchone()
123        default_column_collation = row[0] if row else ""
124        # information_schema database gives more accurate results for some figures:
125        # - varchar length returned by cursor.description is an internal length,
126        #   not visible length (#5725)
127        # - precision and scale (for decimal fields) (#5014)
128        # - auto_increment is not available in cursor.description
129        cursor.execute(
130            """
131            SELECT
132                column_name, data_type, character_maximum_length,
133                numeric_precision, numeric_scale, extra, column_default,
134                CASE
135                    WHEN collation_name = %s THEN NULL
136                    ELSE collation_name
137                END AS collation_name,
138                CASE
139                    WHEN column_type LIKE '%% unsigned' THEN 1
140                    ELSE 0
141                END AS is_unsigned,
142                column_comment
143            FROM information_schema.columns
144            WHERE table_name = %s AND table_schema = DATABASE()
145            """,
146            [default_column_collation, table_name],
147        )
148        field_info = {line[0]: InfoLine(*line) for line in cursor.fetchall()}
149
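        # Select a single row so cursor.description describes this table's columns.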
        cursor.execute(
            f"SELECT * FROM {self.connection.ops.quote_name(table_name)} LIMIT 1"
        )

        def to_int(i: Any) -> Any:
            return int(i) if i is not None else i

        fields = []
        for line in cursor.description:
            info = field_info[line[0]]
            fields.append(
                FieldInfo(
                    *line[:2],
                    to_int(info.max_len) or line[2],
                    to_int(info.max_len) or line[3],
                    to_int(info.num_prec) or line[4],
                    to_int(info.num_scale) or line[5],
                    line[6],
                    info.column_default,
                    info.collation,
                    info.extra,
                    info.is_unsigned,
                    line[0] in json_constraints,
                    info.comment,
                )
            )
        return fields

    def get_sequences(
        self, cursor: Any, table_name: str, table_fields: tuple[Any, ...] = ()
    ) -> list[dict[str, Any]]:
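        """Return the table's auto-increment column, if any."""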
        for field_info in self.get_table_description(cursor, table_name):
            if "auto_increment" in field_info.extra:
                # MySQL allows only one auto-increment column per table.
                return [{"table": table_name, "column": field_info.name}]
        return []

    def get_relations(self, cursor: Any, table_name: str) -> dict[str, tuple[str, str]]:
        """
        Return a dictionary of {field_name: (field_name_other_table, other_table)}
        representing all foreign keys in the given table.
        """
        cursor.execute(
            """
            SELECT column_name, referenced_column_name, referenced_table_name
            FROM information_schema.key_column_usage
            WHERE table_name = %s
                AND table_schema = DATABASE()
                AND referenced_table_name IS NOT NULL
                AND referenced_column_name IS NOT NULL
            """,
            [table_name],
        )
        return {
            field_name: (other_field, other_table)
            for field_name, other_field, other_table in cursor.fetchall()
        }

    def get_storage_engine(self, cursor: Any, table_name: str) -> str:
209        """
210        Retrieve the storage engine for a given table. Return the default
211        storage engine if the table doesn't exist.
212        """
213        cursor.execute(
214            """
215            SELECT engine
216            FROM information_schema.tables
217            WHERE
218                table_name = %s AND
219                table_schema = DATABASE()
220            """,
221            [table_name],
222        )
223        result = cursor.fetchone()
224        if not result:
225            return self.connection.features._mysql_storage_engine
226        return result[0]
227
228    def _parse_constraint_columns(
229        self, check_clause: str, columns: set[str]
230    ) -> OrderedSet:
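        """
        Return the column names (from `columns`) referenced by a CHECK
        constraint clause, preserving the order in which they appear.
        """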
        check_columns: OrderedSet = OrderedSet()
        statement = sqlparse.parse(check_clause)[0]
        tokens = (token for token in statement.flatten() if not token.is_whitespace)
        for token in tokens:
            if (
                token.ttype == sqlparse.tokens.Name
                and self.connection.ops.quote_name(token.value) == token.value
                and token.value[1:-1] in columns
            ):
                check_columns.add(token.value[1:-1])
        return check_columns

    def get_constraints(
        self, cursor: Any, table_name: str
    ) -> dict[str, dict[str, Any]]:
246        """
247        Retrieve any constraints or keys (unique, pk, fk, check, index) across
248        one or more columns.
249        """
250        constraints: dict[str, dict[str, Any]] = {}
251        # Get the actual constraint names and columns
252        name_query = """
253            SELECT kc.`constraint_name`, kc.`column_name`,
254                kc.`referenced_table_name`, kc.`referenced_column_name`,
255                c.`constraint_type`
256            FROM
257                information_schema.key_column_usage AS kc,
258                information_schema.table_constraints AS c
259            WHERE
260                kc.table_schema = DATABASE() AND
261                c.table_schema = kc.table_schema AND
262                c.constraint_name = kc.constraint_name AND
263                c.constraint_type != 'CHECK' AND
264                kc.table_name = %s
265            ORDER BY kc.`ordinal_position`
266        """
267        cursor.execute(name_query, [table_name])
268        for constraint, column, ref_table, ref_column, kind in cursor.fetchall():
269            if constraint not in constraints:
270                constraints[constraint] = {
271                    "columns": OrderedSet(),
272                    "primary_key": kind == "PRIMARY KEY",
273                    "unique": kind in {"PRIMARY KEY", "UNIQUE"},
274                    "index": False,
275                    "check": False,
276                    "foreign_key": (ref_table, ref_column) if ref_column else None,
277                }
278                if self.connection.features.supports_index_column_ordering:
279                    constraints[constraint]["orders"] = []
280            constraints[constraint]["columns"].add(column)
281        # Add check constraints.
282        if self.connection.features.can_introspect_check_constraints:
283            unnamed_constraints_index = 0
284            columns = {
285                info.name for info in self.get_table_description(cursor, table_name)
286            }
287            if self.connection.mysql_is_mariadb:
288                type_query = """
289                    SELECT c.constraint_name, c.check_clause
290                    FROM information_schema.check_constraints AS c
291                    WHERE
292                        c.constraint_schema = DATABASE() AND
293                        c.table_name = %s
294                """
295            else:
296                type_query = """
297                    SELECT cc.constraint_name, cc.check_clause
298                    FROM
299                        information_schema.check_constraints AS cc,
300                        information_schema.table_constraints AS tc
301                    WHERE
302                        cc.constraint_schema = DATABASE() AND
303                        tc.table_schema = cc.constraint_schema AND
304                        cc.constraint_name = tc.constraint_name AND
305                        tc.constraint_type = 'CHECK' AND
306                        tc.table_name = %s
307                """
308            cursor.execute(type_query, [table_name])
309            for constraint, check_clause in cursor.fetchall():
310                constraint_columns = self._parse_constraint_columns(
311                    check_clause, columns
312                )
313                # Ensure uniqueness of unnamed constraints. Unnamed unique
314                # and check columns constraints have the same name as
315                # a column.
316                if set(constraint_columns) == {constraint}:
317                    unnamed_constraints_index += 1
318                    constraint = f"__unnamed_constraint_{unnamed_constraints_index}__"
319                constraints[constraint] = {
320                    "columns": constraint_columns,
321                    "primary_key": False,
322                    "unique": False,
323                    "index": False,
324                    "check": True,
325                    "foreign_key": None,
326                }
327        # Now add in the indexes
328        cursor.execute(f"SHOW INDEX FROM {self.connection.ops.quote_name(table_name)}")
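        # SHOW INDEX rows are Table, Non_unique, Key_name, Seq_in_index,
        # Column_name, Collation (positions 0-5), with Index_type at position 10.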
        for table, non_unique, index, colseq, column, order, type_ in [
            x[:6] + (x[10],) for x in cursor.fetchall()
        ]:
            if index not in constraints:
                constraints[index] = {
                    "columns": OrderedSet(),
                    "primary_key": False,
                    "unique": not non_unique,
                    "check": False,
                    "foreign_key": None,
                }
                if self.connection.features.supports_index_column_ordering:
                    constraints[index]["orders"] = []
            constraints[index]["index"] = True
            constraints[index]["type"] = (
                Index.suffix if type_ == "BTREE" else type_.lower()
            )
            constraints[index]["columns"].add(column)
            if self.connection.features.supports_index_column_ordering:
                constraints[index]["orders"].append("DESC" if order == "D" else "ASC")
        # Convert the ordered sets of columns to lists.
        for constraint in constraints.values():
            constraint["columns"] = list(constraint["columns"])
        return constraints