1from __future__ import annotations
2
3from typing import TYPE_CHECKING, Any, NamedTuple
4
5import sqlparse
6from MySQLdb.constants import FIELD_TYPE
7
8from plain.models.backends.base.introspection import (
9 BaseDatabaseIntrospection,
10)
11from plain.models.backends.utils import CursorWrapper
12from plain.models.indexes import Index
13from plain.utils.datastructures import OrderedSet
14
15if TYPE_CHECKING:
16 from .base import MySQLDatabaseWrapper
17
18
class FieldInfo(NamedTuple):
    """
    MySQL-specific FieldInfo extending base with additional metadata.

    One entry per column as returned by
    ``DatabaseIntrospection.get_table_description``: the DB-API
    ``cursor.description`` values, refined with data from
    ``information_schema.columns``, plus MySQL/MariaDB-only extras.
    """

    # Fields from BaseFieldInfo (redeclared here rather than inherited;
    # presumably kept in the same order as the base tuple so positional
    # access stays compatible -- verify against the base introspection module)
    name: str
    type_code: Any
    display_size: int | None
    internal_size: int | None
    precision: int | None
    scale: int | None
    null_ok: bool | None
    default: Any
    # None when the column uses the table's default collation.
    collation: str | None
    # MySQL-specific extensions
    # Raw information_schema.columns.extra value (e.g. contains
    # "auto_increment" for auto-increment columns).
    extra: str
    # Derived from column_type ending in " unsigned"; the query yields 1/0.
    is_unsigned: bool
    # True when a MariaDB json_valid() check constraint guards the column.
    has_json_constraint: bool
    # information_schema.columns.column_comment.
    comment: str | None
37
38
class InfoLine(NamedTuple):
    """
    Information about a column from MySQL's information schema.

    Instances are built positionally (``InfoLine(*row)``) from the
    ``information_schema.columns`` query in
    ``DatabaseIntrospection.get_table_description``, so the field order here
    must match that query's SELECT column order.
    """

    col_name: str
    data_type: str
    # character_maximum_length
    max_len: int | None
    # numeric_precision
    num_prec: int | None
    # numeric_scale
    num_scale: int | None
    extra: str
    column_default: Any
    # NULL (None) when equal to the table's default collation.
    collation: str | None
    # 1/0 from a CASE expression on column_type.
    is_unsigned: bool
    comment: str | None
52
53
class TableInfo(NamedTuple):
    """
    MySQL-specific TableInfo extending base with comment support.

    Produced by ``DatabaseIntrospection.get_table_list`` from
    ``information_schema.tables``.
    """

    # Fields from BaseTableInfo
    name: str
    # "t" for a base table, "v" for a view (unknown types fall back to "t").
    type: str
    # MySQL-specific extension
    # information_schema.tables.table_comment.
    comment: str | None
62
63
class DatabaseIntrospection(BaseDatabaseIntrospection):
    """Schema introspection for MySQL and MariaDB databases."""

    # Type hint: narrow connection type to MySQL-specific wrapper
    connection: MySQLDatabaseWrapper

    # MySQLdb FIELD_TYPE codes (as found in cursor.description) mapped to
    # model field class names.
    data_types_reverse = {
        FIELD_TYPE.BLOB: "TextField",
        FIELD_TYPE.CHAR: "CharField",
        FIELD_TYPE.DECIMAL: "DecimalField",
        FIELD_TYPE.NEWDECIMAL: "DecimalField",
        FIELD_TYPE.DATE: "DateField",
        FIELD_TYPE.DATETIME: "DateTimeField",
        FIELD_TYPE.DOUBLE: "FloatField",
        FIELD_TYPE.FLOAT: "FloatField",
        FIELD_TYPE.INT24: "IntegerField",
        FIELD_TYPE.JSON: "JSONField",
        FIELD_TYPE.LONG: "IntegerField",
        FIELD_TYPE.LONGLONG: "BigIntegerField",
        FIELD_TYPE.SHORT: "SmallIntegerField",
        FIELD_TYPE.STRING: "CharField",
        FIELD_TYPE.TIME: "TimeField",
        FIELD_TYPE.TIMESTAMP: "DateTimeField",
        FIELD_TYPE.TINY: "IntegerField",
        FIELD_TYPE.TINY_BLOB: "TextField",
        FIELD_TYPE.MEDIUM_BLOB: "TextField",
        FIELD_TYPE.LONG_BLOB: "TextField",
        FIELD_TYPE.VAR_STRING: "CharField",
    }

    def get_field_type(self, data_type: Any, description: Any) -> str:
        """
        Return the model field class name for a column.

        Start from the base class's mapping, then refine it with the
        MySQL-specific metadata on ``description`` (a ``FieldInfo``):
        auto-increment BIGINT columns become ``PrimaryKeyField``, unsigned
        integer columns become their ``Positive*`` variant, and columns
        guarded by a MariaDB ``json_valid()`` check become ``JSONField``.
        """
        field_type = super().get_field_type(data_type, description)
        if "auto_increment" in description.extra:
            if field_type == "BigIntegerField":
                return "PrimaryKeyField"
        if description.is_unsigned:
            if field_type == "BigIntegerField":
                return "PositiveBigIntegerField"
            elif field_type == "IntegerField":
                return "PositiveIntegerField"
            elif field_type == "SmallIntegerField":
                return "PositiveSmallIntegerField"
        # JSON data type is an alias for LONGTEXT in MariaDB, use check
        # constraints clauses to introspect JSONField.
        if description.has_json_constraint:
            return "JSONField"
        return field_type

    def get_table_list(self, cursor: CursorWrapper) -> list[TableInfo]:
        """
        Return a list of table and view names in the current database.

        Each ``TableInfo`` carries the table name, ``"t"`` (base table) or
        ``"v"`` (view), and the table comment. Unrecognized table types are
        reported as ``"t"``.
        """
        cursor.execute(
            """
            SELECT
                table_name,
                table_type,
                table_comment
            FROM information_schema.tables
            WHERE table_schema = DATABASE()
            """
        )
        table_types = {"BASE TABLE": "t", "VIEW": "v"}
        return [
            TableInfo(name, table_types.get(kind, "t"), comment)
            for name, kind, comment in cursor.fetchall()
        ]

    def get_table_description(
        self, cursor: CursorWrapper, table_name: str
    ) -> list[FieldInfo]:
        """
        Return a description of the table with the DB-API cursor.description
        interface.

        Each ``FieldInfo`` merges ``cursor.description`` with the more
        accurate per-column data from ``information_schema.columns``.
        ``collation`` is ``None`` when the column uses the table's default
        collation.
        """
        json_constraints: set[str] = set()
        if (
            self.connection.mysql_is_mariadb
            and self.connection.features.can_introspect_json_field
        ):
            # JSON data type is an alias for LONGTEXT in MariaDB, select
            # JSON_VALID() constraints to introspect JSONField. Such an
            # unnamed column constraint is named after its column.
            # CONCAT() is required here: "+" is numeric addition in
            # MySQL/MariaDB, which would reduce the right-hand side to 0 and
            # make the comparison match almost any check constraint.
            cursor.execute(
                """
                SELECT c.constraint_name AS column_name
                FROM information_schema.check_constraints AS c
                WHERE
                    c.table_name = %s AND
                    LOWER(c.check_clause) =
                        CONCAT('json_valid(`', LOWER(c.constraint_name), '`)') AND
                    c.constraint_schema = DATABASE()
                """,
                [table_name],
            )
            json_constraints = {row[0] for row in cursor.fetchall()}
        # A default collation for the given table.
        cursor.execute(
            """
            SELECT table_collation
            FROM information_schema.tables
            WHERE table_schema = DATABASE()
            AND table_name = %s
            """,
            [table_name],
        )
        row = cursor.fetchone()
        default_column_collation = row[0] if row else ""
        # information_schema database gives more accurate results for some figures:
        # - varchar length returned by cursor.description is an internal length,
        #   not visible length (#5725)
        # - precision and scale (for decimal fields) (#5014)
        # - auto_increment is not available in cursor.description
        # The SELECT column order must match the InfoLine field order.
        cursor.execute(
            """
            SELECT
                column_name, data_type, character_maximum_length,
                numeric_precision, numeric_scale, extra, column_default,
                CASE
                    WHEN collation_name = %s THEN NULL
                    ELSE collation_name
                END AS collation_name,
                CASE
                    WHEN column_type LIKE '%% unsigned' THEN 1
                    ELSE 0
                END AS is_unsigned,
                column_comment
            FROM information_schema.columns
            WHERE table_name = %s AND table_schema = DATABASE()
            """,
            [default_column_collation, table_name],
        )
        field_info = {line[0]: InfoLine(*line) for line in cursor.fetchall()}

        # Run a trivial query only to populate cursor.description.
        cursor.execute(
            f"SELECT * FROM {self.connection.ops.quote_name(table_name)} LIMIT 1"
        )

        def to_int(i: Any) -> Any:
            return int(i) if i is not None else i

        fields = []
        for line in cursor.description:
            info = field_info[line[0]]
            fields.append(
                FieldInfo(
                    name=line[0],
                    type_code=line[1],
                    display_size=to_int(info.max_len) or line[2],
                    internal_size=to_int(info.max_len) or line[3],
                    precision=to_int(info.num_prec) or line[4],
                    scale=to_int(info.num_scale) or line[5],
                    null_ok=line[6],
                    default=info.column_default,
                    collation=info.collation,
                    extra=info.extra,
                    is_unsigned=info.is_unsigned,
                    has_json_constraint=line[0] in json_constraints,
                    comment=info.comment,
                )
            )
        return fields

    def get_sequences(
        self, cursor: CursorWrapper, table_name: str, table_fields: tuple[Any, ...] = ()
    ) -> list[dict[str, Any]]:
        """
        Return ``[{"table": ..., "column": ...}]`` for the table's
        auto-increment column, or ``[]`` if it has none. ``table_fields`` is
        accepted for interface compatibility and not used here.
        """
        for field_info in self.get_table_description(cursor, table_name):
            if "auto_increment" in field_info.extra:
                # MySQL allows only one auto-increment column per table.
                return [{"table": table_name, "column": field_info.name}]
        return []

    def get_relations(
        self, cursor: CursorWrapper, table_name: str
    ) -> dict[str, tuple[str, str]]:
        """
        Return a dictionary of {field_name: (field_name_other_table, other_table)}
        representing all foreign keys in the given table.
        """
        cursor.execute(
            """
            SELECT column_name, referenced_column_name, referenced_table_name
            FROM information_schema.key_column_usage
            WHERE table_name = %s
                AND table_schema = DATABASE()
                AND referenced_table_name IS NOT NULL
                AND referenced_column_name IS NOT NULL
            """,
            [table_name],
        )
        return {
            field_name: (other_field, other_table)
            for field_name, other_field, other_table in cursor.fetchall()
        }

    def get_storage_engine(self, cursor: CursorWrapper, table_name: str) -> str:
        """
        Retrieve the storage engine for a given table. Return the default
        storage engine if the table doesn't exist.
        """
        cursor.execute(
            """
            SELECT engine
            FROM information_schema.tables
            WHERE
                table_name = %s AND
                table_schema = DATABASE()
            """,
            [table_name],
        )
        result = cursor.fetchone()
        if not result:
            return self.connection.features._mysql_storage_engine
        return result[0]

    def _parse_constraint_columns(
        self, check_clause: str, columns: set[str]
    ) -> OrderedSet:
        """
        Return, in order of appearance, the names from ``columns`` that occur
        as quoted identifiers in ``check_clause``.
        """
        check_columns: OrderedSet = OrderedSet()
        statement = sqlparse.parse(check_clause)[0]
        tokens = (token for token in statement.flatten() if not token.is_whitespace)
        for token in tokens:
            # Only accept identifiers that are already quoted exactly as
            # quote_name() would quote them; strip the quotes to compare.
            if (
                token.ttype == sqlparse.tokens.Name
                and self.connection.ops.quote_name(token.value) == token.value
                and token.value[1:-1] in columns
            ):
                check_columns.add(token.value[1:-1])
        return check_columns

    def get_constraints(
        self, cursor: CursorWrapper, table_name: str
    ) -> dict[str, dict[str, Any]]:
        """
        Retrieve any constraints or keys (unique, pk, fk, check, index) across
        one or more columns.

        Return a mapping of constraint/index name to a dict with the keys
        ``columns`` (list of column names), ``primary_key``, ``unique``,
        ``index``, ``check`` and ``foreign_key`` (``(table, column)`` or
        ``None``). Entries backed by an index also get ``type``, and
        ``orders`` is added when the backend supports index column ordering.
        """
        constraints: dict[str, dict[str, Any]] = {}
        supports_ordering = self.connection.features.supports_index_column_ordering
        # Get the actual constraint names and columns. Constraint names can
        # repeat across tables (every primary key is named "PRIMARY"), so the
        # join also matches on table_name; without it, rows belonging to other
        # tables can leak in and set the wrong constraint type.
        name_query = """
            SELECT kc.`constraint_name`, kc.`column_name`,
                kc.`referenced_table_name`, kc.`referenced_column_name`,
                c.`constraint_type`
            FROM
                information_schema.key_column_usage AS kc,
                information_schema.table_constraints AS c
            WHERE
                kc.table_schema = DATABASE() AND
                c.table_schema = kc.table_schema AND
                c.table_name = kc.table_name AND
                c.constraint_name = kc.constraint_name AND
                c.constraint_type != 'CHECK' AND
                kc.table_name = %s
            ORDER BY kc.`ordinal_position`
        """
        cursor.execute(name_query, [table_name])
        for constraint, column, ref_table, ref_column, kind in cursor.fetchall():
            if constraint not in constraints:
                constraints[constraint] = {
                    "columns": OrderedSet(),
                    "primary_key": kind == "PRIMARY KEY",
                    "unique": kind in {"PRIMARY KEY", "UNIQUE"},
                    "index": False,
                    "check": False,
                    "foreign_key": (ref_table, ref_column) if ref_column else None,
                }
                if supports_ordering:
                    constraints[constraint]["orders"] = []
            constraints[constraint]["columns"].add(column)
        # Add check constraints.
        if self.connection.features.can_introspect_check_constraints:
            unnamed_constraints_index = 0
            columns = {
                info.name for info in self.get_table_description(cursor, table_name)
            }
            if self.connection.mysql_is_mariadb:
                type_query = """
                    SELECT c.constraint_name, c.check_clause
                    FROM information_schema.check_constraints AS c
                    WHERE
                        c.constraint_schema = DATABASE() AND
                        c.table_name = %s
                """
            else:
                type_query = """
                    SELECT cc.constraint_name, cc.check_clause
                    FROM
                        information_schema.check_constraints AS cc,
                        information_schema.table_constraints AS tc
                    WHERE
                        cc.constraint_schema = DATABASE() AND
                        tc.table_schema = cc.constraint_schema AND
                        cc.constraint_name = tc.constraint_name AND
                        tc.constraint_type = 'CHECK' AND
                        tc.table_name = %s
                """
            cursor.execute(type_query, [table_name])
            for constraint, check_clause in cursor.fetchall():
                constraint_columns = self._parse_constraint_columns(
                    check_clause, columns
                )
                # Ensure uniqueness of unnamed constraints. Unnamed unique
                # and check columns constraints have the same name as
                # a column.
                if set(constraint_columns) == {constraint}:
                    unnamed_constraints_index += 1
                    constraint = f"__unnamed_constraint_{unnamed_constraints_index}__"
                constraints[constraint] = {
                    "columns": constraint_columns,
                    "primary_key": False,
                    "unique": False,
                    "index": False,
                    "check": True,
                    "foreign_key": None,
                }
        # Now add in the indexes.
        cursor.execute(f"SHOW INDEX FROM {self.connection.ops.quote_name(table_name)}")
        for row in cursor.fetchall():
            # SHOW INDEX row layout used here: Non_unique at 1, Key_name at 2,
            # Column_name at 4, Collation at 5 ("D" = descending) and
            # Index_type at 10.
            non_unique, index, column, order = row[1], row[2], row[4], row[5]
            type_ = row[10]
            if index not in constraints:
                constraints[index] = {
                    "columns": OrderedSet(),
                    "primary_key": False,
                    "unique": not non_unique,
                    "check": False,
                    "foreign_key": None,
                }
                if supports_ordering:
                    constraints[index]["orders"] = []
            entry = constraints[index]
            entry["index"] = True
            entry["type"] = Index.suffix if type_ == "BTREE" else type_.lower()
            entry["columns"].add(column)
            if supports_ordering:
                entry["orders"].append("DESC" if order == "D" else "ASC")
        # Convert the insertion-ordered sets to lists.
        for constraint in constraints.values():
            constraint["columns"] = list(constraint["columns"])
        return constraints