1from collections import namedtuple
2
3import sqlparse
4from MySQLdb.constants import FIELD_TYPE
5
6from plain.models.backends.base.introspection import BaseDatabaseIntrospection
7from plain.models.backends.base.introspection import FieldInfo as BaseFieldInfo
8from plain.models.backends.base.introspection import TableInfo as BaseTableInfo
9from plain.models.indexes import Index
10from plain.utils.datastructures import OrderedSet
11
# Backend-agnostic FieldInfo extended with MySQL/MariaDB-specific attributes:
# the information_schema ``extra`` value (e.g. containing "auto_increment"),
# whether the column type is unsigned, whether a MariaDB JSON_VALID() check
# constraint guards the column, and the column comment.
FieldInfo = namedtuple(
    "FieldInfo",
    (
        *BaseFieldInfo._fields,
        "extra",
        "is_unsigned",
        "has_json_constraint",
        "comment",
    ),
)
# One row of the information_schema.columns query issued by
# DatabaseIntrospection.get_table_description(), in SELECT-column order.
InfoLine = namedtuple(
    "InfoLine",
    [
        "col_name",
        "data_type",
        "max_len",
        "num_prec",
        "num_scale",
        "extra",
        "column_default",
        "collation",
        "is_unsigned",
        "comment",
    ],
)
# Backend-agnostic TableInfo plus the table comment.
TableInfo = namedtuple("TableInfo", (*BaseTableInfo._fields, "comment"))
22
23
class DatabaseIntrospection(BaseDatabaseIntrospection):
    """
    Schema introspection for the MySQL / MariaDB backend.

    Table, column, relation and constraint metadata is read from
    ``information_schema`` (restricted to the current ``DATABASE()``) and
    from ``SHOW INDEX``; column type codes come from ``cursor.description``.
    """

    # MySQLdb FIELD_TYPE codes (as reported in cursor.description) mapped to
    # model field class names.
    data_types_reverse = {
        FIELD_TYPE.BLOB: "TextField",
        FIELD_TYPE.CHAR: "CharField",
        FIELD_TYPE.DECIMAL: "DecimalField",
        FIELD_TYPE.NEWDECIMAL: "DecimalField",
        FIELD_TYPE.DATE: "DateField",
        FIELD_TYPE.DATETIME: "DateTimeField",
        FIELD_TYPE.DOUBLE: "FloatField",
        FIELD_TYPE.FLOAT: "FloatField",
        FIELD_TYPE.INT24: "IntegerField",
        FIELD_TYPE.JSON: "JSONField",
        FIELD_TYPE.LONG: "IntegerField",
        FIELD_TYPE.LONGLONG: "BigIntegerField",
        FIELD_TYPE.SHORT: "SmallIntegerField",
        FIELD_TYPE.STRING: "CharField",
        FIELD_TYPE.TIME: "TimeField",
        FIELD_TYPE.TIMESTAMP: "DateTimeField",
        FIELD_TYPE.TINY: "IntegerField",
        FIELD_TYPE.TINY_BLOB: "TextField",
        FIELD_TYPE.MEDIUM_BLOB: "TextField",
        FIELD_TYPE.LONG_BLOB: "TextField",
        FIELD_TYPE.VAR_STRING: "CharField",
    }

    def get_field_type(self, data_type, description):
        """
        Return the model field type name for a column.

        The type chosen by the base class is refined with the MySQL-specific
        attributes of ``description`` (a FieldInfo): auto_increment integer
        columns become *AutoField, unsigned integer columns become
        Positive*Field, and columns with a JSON_VALID() check constraint
        become JSONField.
        """
        field_type = super().get_field_type(data_type, description)
        if "auto_increment" in description.extra:
            if field_type == "IntegerField":
                return "AutoField"
            elif field_type == "BigIntegerField":
                return "BigAutoField"
            elif field_type == "SmallIntegerField":
                return "SmallAutoField"
        if description.is_unsigned:
            if field_type == "BigIntegerField":
                return "PositiveBigIntegerField"
            elif field_type == "IntegerField":
                return "PositiveIntegerField"
            elif field_type == "SmallIntegerField":
                return "PositiveSmallIntegerField"
        # JSON data type is an alias for LONGTEXT in MariaDB, use check
        # constraints clauses to introspect JSONField.
        if description.has_json_constraint:
            return "JSONField"
        return field_type

    def get_table_list(self, cursor):
        """
        Return a list of table and view names in the current database.

        Each entry is a TableInfo whose type is "t" for base tables, "v" for
        views (None for any other table_type), followed by the table comment.
        """
        cursor.execute(
            """
            SELECT
                table_name,
                table_type,
                table_comment
            FROM information_schema.tables
            WHERE table_schema = DATABASE()
            """
        )
        return [
            TableInfo(row[0], {"BASE TABLE": "t", "VIEW": "v"}.get(row[1]), row[2])
            for row in cursor.fetchall()
        ]

    def get_table_description(self, cursor, table_name):
        """
        Return a description of the table with the DB-API cursor.description
        interface.

        Each column is returned as a FieldInfo. Length, precision, scale,
        default, collation, extra, unsigned flag and comment are taken from
        information_schema.columns where available, falling back to
        cursor.description.
        """
        json_constraints = set()
        if (
            self.connection.mysql_is_mariadb
            and self.connection.features.can_introspect_json_field
        ):
            # JSON data type is an alias for LONGTEXT in MariaDB, select
            # JSON_VALID() constraints to introspect JSONField.
            # CONCAT() is required here: in MySQL/MariaDB "+" is numeric
            # addition, which coerces the string operands to numbers and lets
            # unrelated check constraints satisfy the comparison.
            cursor.execute(
                """
                SELECT c.constraint_name AS column_name
                FROM information_schema.check_constraints AS c
                WHERE
                    c.table_name = %s AND
                    LOWER(c.check_clause) =
                        CONCAT('json_valid(`', LOWER(c.constraint_name), '`)') AND
                    c.constraint_schema = DATABASE()
                """,
                [table_name],
            )
            json_constraints = {row[0] for row in cursor.fetchall()}
        # A default collation for the given table. Column collations equal to
        # it are reported as NULL below.
        cursor.execute(
            """
            SELECT table_collation
            FROM information_schema.tables
            WHERE table_schema = DATABASE()
            AND table_name = %s
            """,
            [table_name],
        )
        row = cursor.fetchone()
        default_column_collation = row[0] if row else ""
        # information_schema database gives more accurate results for some figures:
        # - varchar length returned by cursor.description is an internal length,
        #   not visible length (#5725)
        # - precision and scale (for decimal fields) (#5014)
        # - auto_increment is not available in cursor.description
        cursor.execute(
            """
            SELECT
                column_name, data_type, character_maximum_length,
                numeric_precision, numeric_scale, extra, column_default,
                CASE
                    WHEN collation_name = %s THEN NULL
                    ELSE collation_name
                END AS collation_name,
                CASE
                    WHEN column_type LIKE '%% unsigned' THEN 1
                    ELSE 0
                END AS is_unsigned,
                column_comment
            FROM information_schema.columns
            WHERE table_name = %s AND table_schema = DATABASE()
            """,
            [default_column_collation, table_name],
        )
        field_info = {line[0]: InfoLine(*line) for line in cursor.fetchall()}

        # Run a trivial query only to obtain cursor.description for the table.
        cursor.execute(
            "SELECT * FROM %s LIMIT 1" % self.connection.ops.quote_name(table_name)
        )

        def to_int(i):
            # information_schema may return numeric figures as non-int types
            # or NULL; keep None as-is.
            return int(i) if i is not None else i

        fields = []
        for line in cursor.description:
            info = field_info[line[0]]
            fields.append(
                FieldInfo(
                    *line[:2],
                    # Both display_size and internal_size use the declared
                    # maximum length when information_schema provides one.
                    to_int(info.max_len) or line[2],
                    to_int(info.max_len) or line[3],
                    to_int(info.num_prec) or line[4],
                    to_int(info.num_scale) or line[5],
                    line[6],
                    info.column_default,
                    info.collation,
                    info.extra,
                    info.is_unsigned,
                    line[0] in json_constraints,
                    info.comment,
                )
            )
        return fields

    def get_sequences(self, cursor, table_name, table_fields=()):
        """
        Return ``[{"table": table_name, "column": name}]`` for the table's
        auto_increment column, or an empty list if it has none.

        ``table_fields`` is accepted for interface compatibility and unused.
        """
        for field_info in self.get_table_description(cursor, table_name):
            if "auto_increment" in field_info.extra:
                # MySQL allows only one auto-increment column per table.
                return [{"table": table_name, "column": field_info.name}]
        return []

    def get_relations(self, cursor, table_name):
        """
        Return a dictionary of {field_name: (field_name_other_table, other_table)}
        representing all foreign keys in the given table.
        """
        cursor.execute(
            """
            SELECT column_name, referenced_column_name, referenced_table_name
            FROM information_schema.key_column_usage
            WHERE table_name = %s
                AND table_schema = DATABASE()
                AND referenced_table_name IS NOT NULL
                AND referenced_column_name IS NOT NULL
            """,
            [table_name],
        )
        return {
            field_name: (other_field, other_table)
            for field_name, other_field, other_table in cursor.fetchall()
        }

    def get_storage_engine(self, cursor, table_name):
        """
        Retrieve the storage engine for a given table. Return the default
        storage engine if the table doesn't exist.
        """
        cursor.execute(
            """
            SELECT engine
            FROM information_schema.tables
            WHERE
                table_name = %s AND
                table_schema = DATABASE()
            """,
            [table_name],
        )
        result = cursor.fetchone()
        if not result:
            return self.connection.features._mysql_storage_engine
        return result[0]

    def _parse_constraint_columns(self, check_clause, columns):
        """
        Return an OrderedSet of the names in ``columns`` that appear as
        quoted identifiers in the SQL ``check_clause``, in order of first
        appearance.
        """
        check_columns = OrderedSet()
        statement = sqlparse.parse(check_clause)[0]
        tokens = (token for token in statement.flatten() if not token.is_whitespace)
        for token in tokens:
            # Presumably quote_name() leaves an already-quoted identifier
            # unchanged, so this keeps only quoted names; [1:-1] strips the
            # surrounding quote characters.
            if (
                token.ttype == sqlparse.tokens.Name
                and self.connection.ops.quote_name(token.value) == token.value
                and token.value[1:-1] in columns
            ):
                check_columns.add(token.value[1:-1])
        return check_columns

    def get_constraints(self, cursor, table_name):
        """
        Retrieve any constraints or keys (unique, pk, fk, check, index) across
        one or more columns.

        Return a dict keyed by constraint/index name. Each value holds
        "columns" (list), "primary_key", "unique", "index", "check",
        "foreign_key" ((table, column) or None), and for indexes "type", plus
        "orders" when the backend supports index column ordering.
        """
        constraints = {}
        # Get the actual constraint names and columns. The join must also
        # match table_name: constraint names (e.g. PRIMARY, or index-backed
        # names) are only unique per table, so without it constraint_type
        # could be taken from a same-named constraint on another table.
        name_query = """
            SELECT kc.`constraint_name`, kc.`column_name`,
                kc.`referenced_table_name`, kc.`referenced_column_name`,
                c.`constraint_type`
            FROM
                information_schema.key_column_usage AS kc,
                information_schema.table_constraints AS c
            WHERE
                kc.table_schema = DATABASE() AND
                c.table_schema = kc.table_schema AND
                c.table_name = kc.table_name AND
                c.constraint_name = kc.constraint_name AND
                c.constraint_type != 'CHECK' AND
                kc.table_name = %s
            ORDER BY kc.`ordinal_position`
        """
        cursor.execute(name_query, [table_name])
        for constraint, column, ref_table, ref_column, kind in cursor.fetchall():
            if constraint not in constraints:
                constraints[constraint] = {
                    "columns": OrderedSet(),
                    "primary_key": kind == "PRIMARY KEY",
                    "unique": kind in {"PRIMARY KEY", "UNIQUE"},
                    "index": False,
                    "check": False,
                    "foreign_key": (ref_table, ref_column) if ref_column else None,
                }
                if self.connection.features.supports_index_column_ordering:
                    constraints[constraint]["orders"] = []
            constraints[constraint]["columns"].add(column)
        # Add check constraints.
        if self.connection.features.can_introspect_check_constraints:
            unnamed_constraints_index = 0
            columns = {
                info.name for info in self.get_table_description(cursor, table_name)
            }
            if self.connection.mysql_is_mariadb:
                type_query = """
                    SELECT c.constraint_name, c.check_clause
                    FROM information_schema.check_constraints AS c
                    WHERE
                        c.constraint_schema = DATABASE() AND
                        c.table_name = %s
                """
            else:
                type_query = """
                    SELECT cc.constraint_name, cc.check_clause
                    FROM
                        information_schema.check_constraints AS cc,
                        information_schema.table_constraints AS tc
                    WHERE
                        cc.constraint_schema = DATABASE() AND
                        tc.table_schema = cc.constraint_schema AND
                        cc.constraint_name = tc.constraint_name AND
                        tc.constraint_type = 'CHECK' AND
                        tc.table_name = %s
                """
            cursor.execute(type_query, [table_name])
            for constraint, check_clause in cursor.fetchall():
                constraint_columns = self._parse_constraint_columns(
                    check_clause, columns
                )
                # Ensure uniqueness of unnamed constraints. Unnamed unique
                # and check columns constraints have the same name as
                # a column.
                if set(constraint_columns) == {constraint}:
                    unnamed_constraints_index += 1
                    constraint = "__unnamed_constraint_%s__" % unnamed_constraints_index
                constraints[constraint] = {
                    "columns": constraint_columns,
                    "primary_key": False,
                    "unique": False,
                    "index": False,
                    "check": True,
                    "foreign_key": None,
                }
        # Now add in the indexes. SHOW INDEX rows: the first six columns are
        # Table, Non_unique, Key_name, Seq_in_index, Column_name, Collation;
        # index 10 is Index_type.
        cursor.execute(
            "SHOW INDEX FROM %s" % self.connection.ops.quote_name(table_name)
        )
        for table, non_unique, index, colseq, column, order, type_ in [
            x[:6] + (x[10],) for x in cursor.fetchall()
        ]:
            if index not in constraints:
                constraints[index] = {
                    "columns": OrderedSet(),
                    "primary_key": False,
                    "unique": not non_unique,
                    "check": False,
                    "foreign_key": None,
                }
                if self.connection.features.supports_index_column_ordering:
                    constraints[index]["orders"] = []
            constraints[index]["index"] = True
            constraints[index]["type"] = (
                Index.suffix if type_ == "BTREE" else type_.lower()
            )
            constraints[index]["columns"].add(column)
            if self.connection.features.supports_index_column_ordering:
                constraints[index]["orders"].append("DESC" if order == "D" else "ASC")
        # Convert the ordered sets to lists
        for constraint in constraints.values():
            constraint["columns"] = list(constraint["columns"])
        return constraints