1from collections import namedtuple
2
3import sqlparse
4from MySQLdb.constants import FIELD_TYPE
5
6from plain.models.backends.base.introspection import BaseDatabaseIntrospection
7from plain.models.backends.base.introspection import FieldInfo as BaseFieldInfo
8from plain.models.backends.base.introspection import TableInfo as BaseTableInfo
9from plain.models.indexes import Index
10from plain.utils.datastructures import OrderedSet
11
# Column description: the base FieldInfo fields followed by the extra
# attributes this backend collects for each column.
FieldInfo = namedtuple(
    "FieldInfo",
    (
        *BaseFieldInfo._fields,
        "extra",
        "is_unsigned",
        "has_json_constraint",
        "comment",
    ),
)

# One row of the information_schema.columns query issued by
# DatabaseIntrospection.get_table_description(); the field order matches the
# order of the columns selected there.
InfoLine = namedtuple(
    "InfoLine",
    [
        "col_name",
        "data_type",
        "max_len",
        "num_prec",
        "num_scale",
        "extra",
        "column_default",
        "collation",
        "is_unsigned",
        "comment",
    ],
)

# Table description: the base TableInfo fields plus the table comment.
TableInfo = namedtuple("TableInfo", (*BaseTableInfo._fields, "comment"))
22
23
class DatabaseIntrospection(BaseDatabaseIntrospection):
    """Database introspection for the MySQL backend (MySQL and MariaDB)."""

    # Maps MySQLdb FIELD_TYPE codes to model field class names.
    data_types_reverse = {
        FIELD_TYPE.BLOB: "TextField",
        FIELD_TYPE.CHAR: "CharField",
        FIELD_TYPE.DECIMAL: "DecimalField",
        FIELD_TYPE.NEWDECIMAL: "DecimalField",
        FIELD_TYPE.DATE: "DateField",
        FIELD_TYPE.DATETIME: "DateTimeField",
        FIELD_TYPE.DOUBLE: "FloatField",
        FIELD_TYPE.FLOAT: "FloatField",
        FIELD_TYPE.INT24: "IntegerField",
        FIELD_TYPE.JSON: "JSONField",
        FIELD_TYPE.LONG: "IntegerField",
        FIELD_TYPE.LONGLONG: "BigIntegerField",
        FIELD_TYPE.SHORT: "SmallIntegerField",
        FIELD_TYPE.STRING: "CharField",
        FIELD_TYPE.TIME: "TimeField",
        FIELD_TYPE.TIMESTAMP: "DateTimeField",
        FIELD_TYPE.TINY: "IntegerField",
        FIELD_TYPE.TINY_BLOB: "TextField",
        FIELD_TYPE.MEDIUM_BLOB: "TextField",
        FIELD_TYPE.LONG_BLOB: "TextField",
        FIELD_TYPE.VAR_STRING: "CharField",
    }

    def get_field_type(self, data_type, description):
        """
        Return the field class name for a column.

        The base class result is refined using the MySQL-specific attributes
        carried by ``description``: ``extra``, ``is_unsigned`` and
        ``has_json_constraint``.
        """
        field_type = super().get_field_type(data_type, description)
        if "auto_increment" in description.extra:
            if field_type == "BigIntegerField":
                return "PrimaryKeyField"
        if description.is_unsigned:
            if field_type == "BigIntegerField":
                return "PositiveBigIntegerField"
            elif field_type == "IntegerField":
                return "PositiveIntegerField"
            elif field_type == "SmallIntegerField":
                return "PositiveSmallIntegerField"
        # JSON data type is an alias for LONGTEXT in MariaDB, use check
        # constraints clauses to introspect JSONField.
        if description.has_json_constraint:
            return "JSONField"
        return field_type

    def get_table_list(self, cursor):
        """
        Return a list of TableInfo tuples (name, type, comment) for the tables
        and views in the current database.

        The type is "t" for a base table, "v" for a view, and None for any
        other ``table_type`` reported by information_schema.
        """
        cursor.execute(
            """
            SELECT
                table_name,
                table_type,
                table_comment
            FROM information_schema.tables
            WHERE table_schema = DATABASE()
            """
        )
        type_codes = {"BASE TABLE": "t", "VIEW": "v"}
        return [
            TableInfo(row[0], type_codes.get(row[1]), row[2])
            for row in cursor.fetchall()
        ]

    def get_table_description(self, cursor, table_name):
        """
        Return a description of the table with the DB-API cursor.description
        interface.

        The result is a list of FieldInfo tuples, one per column, in the order
        reported by ``cursor.description`` for ``SELECT * FROM <table>``.
        Raise KeyError if a column in ``cursor.description`` is absent from
        the information_schema.columns result for the table.
        """
        json_constraints = {}
        if (
            self.connection.mysql_is_mariadb
            and self.connection.features.can_introspect_json_field
        ):
            # JSON data type is an alias for LONGTEXT in MariaDB, select
            # JSON_VALID() constraints to introspect JSONField.
            # The expected clause is built with CONCAT(): "+" is the
            # arithmetic operator in MySQL/MariaDB and does not concatenate
            # strings, so using it would compare the clause against a number.
            cursor.execute(
                """
                SELECT c.constraint_name AS column_name
                FROM information_schema.check_constraints AS c
                WHERE
                    c.table_name = %s AND
                    LOWER(c.check_clause) =
                        CONCAT('json_valid(`', LOWER(c.constraint_name), '`)') AND
                    c.constraint_schema = DATABASE()
                """,
                [table_name],
            )
            json_constraints = {row[0] for row in cursor.fetchall()}
        # A default collation for the given table.
        cursor.execute(
            """
            SELECT  table_collation
            FROM    information_schema.tables
            WHERE   table_schema = DATABASE()
            AND     table_name = %s
            """,
            [table_name],
        )
        row = cursor.fetchone()
        default_column_collation = row[0] if row else ""
        # information_schema database gives more accurate results for some figures:
        # - varchar length returned by cursor.description is an internal length,
        #   not visible length (#5725)
        # - precision and scale (for decimal fields) (#5014)
        # - auto_increment is not available in cursor.description
        # The collation is reported as NULL when it equals the table default.
        cursor.execute(
            """
            SELECT
                column_name, data_type, character_maximum_length,
                numeric_precision, numeric_scale, extra, column_default,
                CASE
                    WHEN collation_name = %s THEN NULL
                    ELSE collation_name
                END AS collation_name,
                CASE
                    WHEN column_type LIKE '%% unsigned' THEN 1
                    ELSE 0
                END AS is_unsigned,
                column_comment
            FROM information_schema.columns
            WHERE table_name = %s AND table_schema = DATABASE()
            """,
            [default_column_collation, table_name],
        )
        field_info = {line[0]: InfoLine(*line) for line in cursor.fetchall()}

        # Run a trivial query only to populate cursor.description.
        cursor.execute(
            f"SELECT * FROM {self.connection.ops.quote_name(table_name)} LIMIT 1"
        )

        def to_int(i):
            return int(i) if i is not None else i

        fields = []
        for line in cursor.description:
            info = field_info[line[0]]
            # Values from information_schema take precedence; a missing (or
            # zero) value falls back to the one in cursor.description.
            fields.append(
                FieldInfo(
                    *line[:2],
                    to_int(info.max_len) or line[2],
                    to_int(info.max_len) or line[3],
                    to_int(info.num_prec) or line[4],
                    to_int(info.num_scale) or line[5],
                    line[6],
                    info.column_default,
                    info.collation,
                    info.extra,
                    info.is_unsigned,
                    line[0] in json_constraints,
                    info.comment,
                )
            )
        return fields

    def get_sequences(self, cursor, table_name, table_fields=()):
        """
        Return a one-item list ``[{"table": ..., "column": ...}]`` describing
        the table's auto-increment column, or an empty list if it has none.

        ``table_fields`` is accepted but not used by this implementation.
        """
        for field_info in self.get_table_description(cursor, table_name):
            if "auto_increment" in field_info.extra:
                # MySQL allows only one auto-increment column per table.
                return [{"table": table_name, "column": field_info.name}]
        return []

    def get_relations(self, cursor, table_name):
        """
        Return a dictionary of {field_name: (field_name_other_table, other_table)}
        representing all foreign keys in the given table.
        """
        cursor.execute(
            """
            SELECT column_name, referenced_column_name, referenced_table_name
            FROM information_schema.key_column_usage
            WHERE table_name = %s
                AND table_schema = DATABASE()
                AND referenced_table_name IS NOT NULL
                AND referenced_column_name IS NOT NULL
            """,
            [table_name],
        )
        return {
            field_name: (other_field, other_table)
            for field_name, other_field, other_table in cursor.fetchall()
        }

    def get_storage_engine(self, cursor, table_name):
        """
        Retrieve the storage engine for a given table. Return the default
        storage engine if the table doesn't exist.
        """
        cursor.execute(
            """
            SELECT engine
            FROM information_schema.tables
            WHERE
                table_name = %s AND
                table_schema = DATABASE()
            """,
            [table_name],
        )
        result = cursor.fetchone()
        if not result:
            return self.connection.features._mysql_storage_engine
        return result[0]

    def _parse_constraint_columns(self, check_clause, columns):
        """
        Return an OrderedSet of the names in ``columns`` that appear as quoted
        identifiers in ``check_clause``, in order of first appearance.
        """
        check_columns = OrderedSet()
        statement = sqlparse.parse(check_clause)[0]
        tokens = (token for token in statement.flatten() if not token.is_whitespace)
        for token in tokens:
            # Only accept Name tokens that are already in quoted form
            # (quote_name() leaves them unchanged); value[1:-1] strips the
            # surrounding quote characters to get the bare column name.
            if (
                token.ttype == sqlparse.tokens.Name
                and self.connection.ops.quote_name(token.value) == token.value
                and token.value[1:-1] in columns
            ):
                check_columns.add(token.value[1:-1])
        return check_columns

    def get_constraints(self, cursor, table_name):
        """
        Retrieve any constraints or keys (unique, pk, fk, check, index) across
        one or more columns.

        Return a dict mapping each constraint/index name to a dict with the
        keys "columns" (list), "primary_key", "unique", "index", "check" and
        "foreign_key" (a ``(table, column)`` tuple or None). Entries seen in
        ``SHOW INDEX`` also get a "type" key, and an "orders" list is present
        on key and index entries when the connection supports index column
        ordering.
        """
        constraints = {}
        # Get the actual constraint names and columns
        name_query = """
            SELECT kc.`constraint_name`, kc.`column_name`,
                kc.`referenced_table_name`, kc.`referenced_column_name`,
                c.`constraint_type`
            FROM
                information_schema.key_column_usage AS kc,
                information_schema.table_constraints AS c
            WHERE
                kc.table_schema = DATABASE() AND
                c.table_schema = kc.table_schema AND
                c.constraint_name = kc.constraint_name AND
                c.constraint_type != 'CHECK' AND
                kc.table_name = %s
            ORDER BY kc.`ordinal_position`
        """
        cursor.execute(name_query, [table_name])
        for constraint, column, ref_table, ref_column, kind in cursor.fetchall():
            if constraint not in constraints:
                constraints[constraint] = {
                    "columns": OrderedSet(),
                    "primary_key": kind == "PRIMARY KEY",
                    "unique": kind in {"PRIMARY KEY", "UNIQUE"},
                    "index": False,
                    "check": False,
                    "foreign_key": (ref_table, ref_column) if ref_column else None,
                }
                if self.connection.features.supports_index_column_ordering:
                    constraints[constraint]["orders"] = []
            constraints[constraint]["columns"].add(column)
        # Add check constraints.
        if self.connection.features.can_introspect_check_constraints:
            unnamed_constraints_index = 0
            columns = {
                info.name for info in self.get_table_description(cursor, table_name)
            }
            if self.connection.mysql_is_mariadb:
                type_query = """
                    SELECT c.constraint_name, c.check_clause
                    FROM information_schema.check_constraints AS c
                    WHERE
                        c.constraint_schema = DATABASE() AND
                        c.table_name = %s
                """
            else:
                type_query = """
                    SELECT cc.constraint_name, cc.check_clause
                    FROM
                        information_schema.check_constraints AS cc,
                        information_schema.table_constraints AS tc
                    WHERE
                        cc.constraint_schema = DATABASE() AND
                        tc.table_schema = cc.constraint_schema AND
                        cc.constraint_name = tc.constraint_name AND
                        tc.constraint_type = 'CHECK' AND
                        tc.table_name = %s
                """
            cursor.execute(type_query, [table_name])
            for constraint, check_clause in cursor.fetchall():
                constraint_columns = self._parse_constraint_columns(
                    check_clause, columns
                )
                # Ensure uniqueness of unnamed constraints. Unnamed unique
                # and check columns constraints have the same name as
                # a column.
                if set(constraint_columns) == {constraint}:
                    unnamed_constraints_index += 1
                    constraint = f"__unnamed_constraint_{unnamed_constraints_index}__"
                constraints[constraint] = {
                    "columns": constraint_columns,
                    "primary_key": False,
                    "unique": False,
                    "index": False,
                    "check": True,
                    "foreign_key": None,
                }
        # Now add in the indexes
        cursor.execute(f"SHOW INDEX FROM {self.connection.ops.quote_name(table_name)}")
        for row in cursor.fetchall():
            # Positions used from each SHOW INDEX row: 1 = non-unique flag,
            # 2 = index name, 4 = column name, 5 = sort order, 10 = index type.
            non_unique, index, column, order, type_ = (
                row[1],
                row[2],
                row[4],
                row[5],
                row[10],
            )
            if index not in constraints:
                constraints[index] = {
                    "columns": OrderedSet(),
                    "primary_key": False,
                    "unique": not non_unique,
                    "check": False,
                    "foreign_key": None,
                }
                if self.connection.features.supports_index_column_ordering:
                    constraints[index]["orders"] = []
            # An entry created from a key constraint above is also an index.
            constraints[index]["index"] = True
            constraints[index]["type"] = (
                Index.suffix if type_ == "BTREE" else type_.lower()
            )
            constraints[index]["columns"].add(column)
            if self.connection.features.supports_index_column_ordering:
                constraints[index]["orders"].append("DESC" if order == "D" else "ASC")
        # Convert the ordered sets to lists
        for constraint in constraints.values():
            constraint["columns"] = list(constraint["columns"])
        return constraints