1from collections import namedtuple
2
3import sqlparse
4
5from plain.models import Index
6from plain.models.backends.base.introspection import (
7 BaseDatabaseIntrospection,
8 TableInfo,
9)
10from plain.models.backends.base.introspection import FieldInfo as BaseFieldInfo
11from plain.models.db import DatabaseError
12from plain.utils.regex_helper import _lazy_re_compile
13
# Column description returned by DatabaseIntrospection.get_table_description():
# the backend-independent FieldInfo fields followed by two SQLite-specific
# flags, ``pk`` and ``has_json_constraint``.
FieldInfo = namedtuple(
    "FieldInfo",
    [*BaseFieldInfo._fields, "pk", "has_json_constraint"],
)

# Matches a sized char/varchar type name such as "varchar(30)" or
# " char ( 4 ) " and captures the size digits in group 1.
field_size_re = _lazy_re_compile(r"^\s*(?:var)?char\s*\(\s*(\d+)\s*\)\s*$")
19
20
def get_field_size(name):
    """Extract the size number from a "varchar(11)" type name.

    SQLite keeps declared column types exactly as written in the DDL, so the
    match is case-insensitive ("VARCHAR(11)" also yields 11). This mirrors
    FlexibleFieldLookupDict, which lowercases type names before lookup.

    Return the size as an int, or None when *name* is not a sized
    char/varchar type.
    """
    m = field_size_re.search(name.lower())
    return int(m[1]) if m else None
25
26
class FlexibleFieldLookupDict:
    """Read-only, dict-like map from SQLite declared types to field class names.

    SQLite accepts arbitrary type names and stores them verbatim, including
    any size/precision arguments (e.g. "varchar(30)" or "decimal(10, 2)"),
    so a plain dictionary lookup would miss them. Lookups here are
    case-insensitive and ignore everything from the first "(" onward, as well
    as surrounding whitespace.
    """

    # Several spellings map to the same field because SQLite does not
    # normalize declared types; it keeps whatever the DDL said.
    base_data_types_reverse = {
        "bool": "BooleanField",
        "boolean": "BooleanField",
        "smallint": "SmallIntegerField",
        "smallint unsigned": "PositiveSmallIntegerField",
        "smallinteger": "SmallIntegerField",
        "int": "IntegerField",
        "integer": "IntegerField",
        "bigint": "BigIntegerField",
        "integer unsigned": "PositiveIntegerField",
        "bigint unsigned": "PositiveBigIntegerField",
        "decimal": "DecimalField",
        "real": "FloatField",
        "text": "TextField",
        "char": "CharField",
        "varchar": "CharField",
        "blob": "BinaryField",
        "date": "DateField",
        "datetime": "DateTimeField",
        "time": "TimeField",
    }

    def __getitem__(self, key):
        """Return the field class name for type *key*; KeyError if unknown."""
        base_type, _paren, _args = key.lower().partition("(")
        return self.base_data_types_reverse[base_type.strip()]
59
60
class DatabaseIntrospection(BaseDatabaseIntrospection):
    """Schema introspection for the SQLite backend.

    Most information comes from ``PRAGMA`` statements and the
    ``sqlite_master`` catalog. Details those don't expose (inline UNIQUE and
    CHECK constraints, column collations, index column orders) are recovered
    by parsing the stored CREATE statements with sqlparse.
    """

    data_types_reverse = FlexibleFieldLookupDict()

    def get_field_type(self, data_type, description):
        """Return the field class name for a column.

        Integer primary keys are reported as AutoField, and columns carrying a
        ``json_valid()`` CHECK constraint are reported as JSONField.
        """
        field_type = super().get_field_type(data_type, description)
        if description.pk and field_type in {
            "BigIntegerField",
            "IntegerField",
            "SmallIntegerField",
        }:
            # No support for BigAutoField or SmallAutoField as SQLite treats
            # all integer primary keys as signed 64-bit integers.
            return "AutoField"
        if description.has_json_constraint:
            return "JSONField"
        return field_type

    def get_table_list(self, cursor):
        """Return a list of table and view names in the current database."""
        # Skip the sqlite_sequence system table used for autoincrement key
        # generation.
        cursor.execute(
            """
            SELECT name, type FROM sqlite_master
            WHERE type in ('table', 'view') AND NOT name='sqlite_sequence'
            ORDER BY name"""
        )
        # row[1][0] is the first letter of the type: "t" (table) or "v" (view).
        return [TableInfo(row[0], row[1][0]) for row in cursor.fetchall()]

    def get_table_description(self, cursor, table_name):
        """
        Return a description of the table with the DB-API cursor.description
        interface.

        Each entry is a FieldInfo that also carries whether the column has
        ``pk == 1`` in ``PRAGMA table_info`` and whether a ``json_valid()``
        CHECK constraint mentions it.

        Raise DatabaseError if ``PRAGMA table_info`` returns no rows.
        """
        cursor.execute(
            "PRAGMA table_info(%s)" % self.connection.ops.quote_name(table_name)
        )
        table_info = cursor.fetchall()
        if not table_info:
            raise DatabaseError(f"Table {table_name} does not exist (empty pragma).")
        collations = self._get_column_collations(cursor, table_name)
        json_columns = set()
        if self.connection.features.can_introspect_json_field:
            for line in table_info:
                column = line[1]
                # LIKE pattern: '%json_valid("<column>")%'.
                json_constraint_sql = '%%json_valid("%s")%%' % column
                has_json_constraint = cursor.execute(
                    """
                    SELECT sql
                    FROM sqlite_master
                    WHERE
                        type = 'table' AND
                        name = %s AND
                        sql LIKE %s
                """,
                    [table_name, json_constraint_sql],
                ).fetchone()
                if has_json_constraint:
                    json_columns.add(column)
        return [
            FieldInfo(
                name,
                data_type,
                get_field_size(data_type),
                None,
                None,
                None,
                not notnull,
                default,
                collations.get(name),
                pk == 1,
                name in json_columns,
            )
            for _cid, name, data_type, notnull, default, pk in table_info
        ]

    def get_sequences(self, cursor, table_name, table_fields=()):
        """Report the table's primary key column as its only sequence."""
        pk_col = self.get_primary_key_column(cursor, table_name)
        return [{"table": table_name, "column": pk_col}]

    def get_relations(self, cursor, table_name):
        """
        Return a dictionary of {column_name: (ref_column_name, ref_table_name)}
        representing all foreign keys in the given table.
        """
        cursor.execute(
            "PRAGMA foreign_key_list(%s)" % self.connection.ops.quote_name(table_name)
        )
        return {
            column_name: (ref_column_name, ref_table_name)
            for (
                _,
                _,
                ref_table_name,
                column_name,
                ref_column_name,
                *_,
            ) in cursor.fetchall()
        }

    def get_primary_key_columns(self, cursor, table_name):
        """Return the names of the columns with a nonzero ``pk`` flag."""
        cursor.execute(
            "PRAGMA table_info(%s)" % self.connection.ops.quote_name(table_name)
        )
        return [name for _, name, *_, pk in cursor.fetchall() if pk]

    def _parse_column_or_constraint_definition(self, tokens, columns):
        """Consume one column or table-constraint definition from *tokens*.

        *tokens* is a shared iterator of non-whitespace sqlparse tokens,
        positioned inside the column list of a CREATE TABLE statement. Tokens
        are consumed up to and including the top-level "," that ends this
        definition, or the ")" that closes the column list.

        Return ``(constraint_name, unique_constraint, check_constraint,
        end_token)``. The two constraint entries are constraint dicts or None.
        *end_token* is the last token consumed, or None if *tokens* was
        already exhausted.
        """
        token = None
        is_constraint_definition = None
        field_name = None
        constraint_name = None
        unique = False
        unique_columns = []
        check = False
        check_columns = []
        braces_deep = 0
        for token in tokens:
            if token.match(sqlparse.tokens.Punctuation, "("):
                braces_deep += 1
            elif token.match(sqlparse.tokens.Punctuation, ")"):
                braces_deep -= 1
                if braces_deep < 0:
                    # End of columns and constraints for table definition.
                    break
            elif braces_deep == 0 and token.match(sqlparse.tokens.Punctuation, ","):
                # End of current column or constraint definition.
                break
            # Detect column or constraint definition by first token.
            if is_constraint_definition is None:
                is_constraint_definition = token.match(
                    sqlparse.tokens.Keyword, "CONSTRAINT"
                )
                if is_constraint_definition:
                    continue
            if is_constraint_definition:
                # Detect constraint name by second token.
                if constraint_name is None:
                    if token.ttype in (sqlparse.tokens.Name, sqlparse.tokens.Keyword):
                        constraint_name = token.value
                    elif token.ttype == sqlparse.tokens.Literal.String.Symbol:
                        constraint_name = token.value[1:-1]
                # Start constraint columns parsing after UNIQUE keyword.
                if token.match(sqlparse.tokens.Keyword, "UNIQUE"):
                    unique = True
                    unique_braces_deep = braces_deep
                elif unique:
                    if unique_braces_deep == braces_deep:
                        if unique_columns:
                            # Stop constraint parsing.
                            unique = False
                        continue
                    if token.ttype in (sqlparse.tokens.Name, sqlparse.tokens.Keyword):
                        unique_columns.append(token.value)
                    elif token.ttype == sqlparse.tokens.Literal.String.Symbol:
                        unique_columns.append(token.value[1:-1])
            else:
                # Detect field name by first token.
                if field_name is None:
                    if token.ttype in (sqlparse.tokens.Name, sqlparse.tokens.Keyword):
                        field_name = token.value
                    elif token.ttype == sqlparse.tokens.Literal.String.Symbol:
                        field_name = token.value[1:-1]
                if token.match(sqlparse.tokens.Keyword, "UNIQUE"):
                    unique_columns = [field_name]
            # Start constraint columns parsing after CHECK keyword.
            if token.match(sqlparse.tokens.Keyword, "CHECK"):
                check = True
                check_braces_deep = braces_deep
            elif check:
                if check_braces_deep == braces_deep:
                    if check_columns:
                        # Stop constraint parsing.
                        check = False
                    continue
                # Only identifiers that are actual table columns count as
                # columns referenced by the CHECK expression.
                if token.ttype in (sqlparse.tokens.Name, sqlparse.tokens.Keyword):
                    if token.value in columns:
                        check_columns.append(token.value)
                elif token.ttype == sqlparse.tokens.Literal.String.Symbol:
                    if token.value[1:-1] in columns:
                        check_columns.append(token.value[1:-1])
        unique_constraint = (
            {
                "unique": True,
                "columns": unique_columns,
                "primary_key": False,
                "foreign_key": None,
                "check": False,
                "index": False,
            }
            if unique_columns
            else None
        )
        check_constraint = (
            {
                "check": True,
                "columns": check_columns,
                "primary_key": False,
                "unique": False,
                "foreign_key": None,
                "index": False,
            }
            if check_columns
            else None
        )
        return constraint_name, unique_constraint, check_constraint, token

    def _parse_table_constraints(self, sql, columns):
        """Return the inline UNIQUE and CHECK constraints declared in *sql*.

        *sql* is a CREATE TABLE statement. *columns* is the set of the table's
        column names, used to recognize column references inside CHECK
        expressions. Unnamed constraints get synthetic
        ``__unnamed_constraint_<n>__`` keys.
        """
        # Check constraint parsing is based on the SQLite syntax diagram.
        # https://www.sqlite.org/syntaxdiagrams.html#table-constraint
        statement = sqlparse.parse(sql)[0]
        constraints = {}
        unnamed_constraints_index = 0
        tokens = (token for token in statement.flatten() if not token.is_whitespace)
        # Go to columns and constraint definition
        for token in tokens:
            if token.match(sqlparse.tokens.Punctuation, "("):
                break
        # Parse columns and constraint definition
        while True:
            (
                constraint_name,
                unique,
                check,
                end_token,
            ) = self._parse_column_or_constraint_definition(tokens, columns)
            # Register the UNIQUE constraint first, then the CHECK constraint.
            for constraint in (unique, check):
                if not constraint:
                    continue
                if constraint_name:
                    constraints[constraint_name] = constraint
                else:
                    unnamed_constraints_index += 1
                    constraints[
                        "__unnamed_constraint_%s__" % unnamed_constraints_index
                    ] = constraint
            # Stop at the ")" closing the column list. Also stop if the token
            # stream ran out before one was found; end_token is then None and
            # calling .match() on it would raise AttributeError.
            if end_token is None or end_token.match(sqlparse.tokens.Punctuation, ")"):
                break
        return constraints

    def get_constraints(self, cursor, table_name):
        """
        Retrieve any constraints or keys (unique, pk, fk, check, index) across
        one or more columns.

        Sources:
        - inline UNIQUE/CHECK constraints: parsed from the CREATE TABLE
          statement;
        - explicitly created indexes: ``PRAGMA index_list`` / ``index_info``;
        - the primary key: ``PRAGMA table_info``;
        - foreign keys: ``PRAGMA foreign_key_list``.

        Return a dict mapping constraint names (synthetic where SQLite has
        none) to constraint description dicts.
        """
        constraints = {}
        # Find inline UNIQUE and CHECK constraints. The table name is bound as
        # a query parameter rather than passed through quote_name(). A
        # double-quoted name is an identifier in SQL and only falls back to a
        # string literal via SQLite's legacy DQS quirk. So it resolves to a
        # column when the table is named like a sqlite_master column ("name",
        # "sql", "type", ...), and errors on builds compiled without DQS.
        row = cursor.execute(
            "SELECT sql FROM sqlite_master WHERE type='table' and name=%s",
            [table_name],
        ).fetchone()
        if row is not None:
            # No row means table_name is a view (or missing), which has no
            # inline constraints to parse.
            columns = {
                info.name for info in self.get_table_description(cursor, table_name)
            }
            constraints.update(self._parse_table_constraints(row[0], columns))

        # Get the index info
        cursor.execute(
            "PRAGMA index_list(%s)" % self.connection.ops.quote_name(table_name)
        )
        for row in cursor.fetchall():
            # SQLite 3.8.9+ has 5 columns, however older versions only give 3
            # columns. Discard last 2 columns if there.
            _seq, index, unique = row[:3]
            # Bind the index name as a parameter for the same reason as the
            # table lookup above.
            cursor.execute(
                "SELECT sql FROM sqlite_master WHERE type='index' AND name=%s",
                [index],
            )
            # There's at most one row.
            (sql,) = cursor.fetchone() or (None,)
            # Inline constraints are already detected in
            # _parse_table_constraints(). The reasons to avoid fetching inline
            # constraints from `PRAGMA index_list` are:
            # - Inline constraints can have a different name and information
            #   than what `PRAGMA index_list` gives.
            # - Not all inline constraints may appear in `PRAGMA index_list`.
            if not sql:
                # An inline constraint
                continue
            # Get the index info for that index
            cursor.execute(
                "PRAGMA index_info(%s)" % self.connection.ops.quote_name(index)
            )
            for _seqno, _cid, column in cursor.fetchall():
                if index not in constraints:
                    constraints[index] = {
                        "columns": [],
                        "primary_key": False,
                        "unique": bool(unique),
                        "foreign_key": None,
                        "check": False,
                        "index": True,
                    }
                constraints[index]["columns"].append(column)
            # Add type and column orders for indexes
            if constraints[index]["index"]:
                # SQLite doesn't support any index type other than b-tree
                constraints[index]["type"] = Index.suffix
                orders = self._get_index_columns_orders(sql)
                if orders is not None:
                    constraints[index]["orders"] = orders
        # Get the PK
        pk_columns = self.get_primary_key_columns(cursor, table_name)
        if pk_columns:
            # SQLite doesn't actually give a name to the PK constraint,
            # so we invent one. This is fine, as the SQLite backend never
            # deletes PK constraints by name, as you can't delete constraints
            # in SQLite; we remake the table with a new PK instead.
            constraints["__primary__"] = {
                "columns": pk_columns,
                "primary_key": True,
                "unique": False,  # It's not actually a unique constraint.
                "foreign_key": None,
                "check": False,
                "index": False,
            }
        # Foreign keys are unnamed in SQLite too; number them fk_0, fk_1, ...
        relations = enumerate(self.get_relations(cursor, table_name).items())
        constraints.update(
            {
                f"fk_{index}": {
                    "columns": [column_name],
                    "primary_key": False,
                    "unique": False,
                    "foreign_key": (ref_table_name, ref_column_name),
                    "check": False,
                    "index": False,
                }
                for index, (column_name, (ref_column_name, ref_table_name)) in relations
            }
        )
        return constraints

    def _get_index_columns_orders(self, sql):
        """Return "ASC"/"DESC" for each column of a CREATE INDEX statement.

        Only the first top-level parenthesized group in *sql* is examined.
        Return None if there is no such group. Assumes column entries are
        separated by ", " (as in statements the backend generates) — TODO
        confirm for hand-written DDL.
        """
        tokens = sqlparse.parse(sql)[0]
        for token in tokens:
            if isinstance(token, sqlparse.sql.Parenthesis):
                columns = str(token).strip("()").split(", ")
                return ["DESC" if info.endswith("DESC") else "ASC" for info in columns]
        return None

    def _get_column_collations(self, cursor, table_name):
        """Map each column name to its COLLATE name, or None if it has none.

        Collations are parsed from the table's CREATE TABLE statement. Return
        an empty dict when *table_name* has no 'table' row in sqlite_master.
        """
        row = cursor.execute(
            """
            SELECT sql
            FROM sqlite_master
            WHERE type = 'table' AND name = %s
        """,
            [table_name],
        ).fetchone()
        if not row:
            return {}

        sql = row[0]
        # Assumes the last top-level token is the parenthesized column list
        # and that definitions are separated by ", " — NOTE(review): this
        # matches backend-generated DDL; confirm for other CREATE statements.
        columns = str(sqlparse.parse(sql)[0][-1]).strip("()").split(", ")
        collations = {}
        for column in columns:
            # column[1:] presumably drops the opening quote of the quoted
            # column name; the closing quote is stripped below.
            tokens = column[1:].split()
            column_name = tokens[0].strip('"')
            for index, token in enumerate(tokens):
                if token == "COLLATE":
                    collation = tokens[index + 1]
                    break
            else:
                collation = None
            collations[column_name] = collation
        return collations