Plain is headed towards 1.0! Subscribe for development updates →

  1from collections import namedtuple
  2
  3import sqlparse
  4
  5from plain.models import Index
  6from plain.models.backends.base.introspection import (
  7    BaseDatabaseIntrospection,
  8    TableInfo,
  9)
 10from plain.models.backends.base.introspection import FieldInfo as BaseFieldInfo
 11from plain.models.db import DatabaseError
 12from plain.utils.regex_helper import _lazy_re_compile
 13
 14FieldInfo = namedtuple(
 15    "FieldInfo", BaseFieldInfo._fields + ("pk", "has_json_constraint")
 16)
 17
 18field_size_re = _lazy_re_compile(r"^\s*(?:var)?char\s*\(\s*(\d+)\s*\)\s*$")
 19
 20
 21def get_field_size(name):
 22    """Extract the size number from a "varchar(11)" type name"""
 23    m = field_size_re.search(name)
 24    return int(m[1]) if m else None
 25
 26
 27# This light wrapper "fakes" a dictionary interface, because some SQLite data
 28# types include variables in them -- e.g. "varchar(30)" -- and can't be matched
 29# as a simple dictionary lookup.
 30class FlexibleFieldLookupDict:
 31    # Maps SQL types to Plain Field types. Some of the SQL types have multiple
 32    # entries here because SQLite allows for anything and doesn't normalize the
 33    # field type; it uses whatever was given.
 34    base_data_types_reverse = {
 35        "bool": "BooleanField",
 36        "boolean": "BooleanField",
 37        "smallint": "SmallIntegerField",
 38        "smallint unsigned": "PositiveSmallIntegerField",
 39        "smallinteger": "SmallIntegerField",
 40        "int": "IntegerField",
 41        "integer": "IntegerField",
 42        "bigint": "BigIntegerField",
 43        "integer unsigned": "PositiveIntegerField",
 44        "bigint unsigned": "PositiveBigIntegerField",
 45        "decimal": "DecimalField",
 46        "real": "FloatField",
 47        "text": "TextField",
 48        "char": "CharField",
 49        "varchar": "CharField",
 50        "blob": "BinaryField",
 51        "date": "DateField",
 52        "datetime": "DateTimeField",
 53        "time": "TimeField",
 54    }
 55
 56    def __getitem__(self, key):
 57        key = key.lower().split("(", 1)[0].strip()
 58        return self.base_data_types_reverse[key]
 59
 60
 61class DatabaseIntrospection(BaseDatabaseIntrospection):
 62    data_types_reverse = FlexibleFieldLookupDict()
 63
 64    def get_field_type(self, data_type, description):
 65        field_type = super().get_field_type(data_type, description)
 66        if description.pk and field_type in {
 67            "BigIntegerField",
 68            "IntegerField",
 69            "SmallIntegerField",
 70        }:
 71            # No support for BigAutoField or SmallAutoField as SQLite treats
 72            # all integer primary keys as signed 64-bit integers.
 73            return "AutoField"
 74        if description.has_json_constraint:
 75            return "JSONField"
 76        return field_type
 77
 78    def get_table_list(self, cursor):
 79        """Return a list of table and view names in the current database."""
 80        # Skip the sqlite_sequence system table used for autoincrement key
 81        # generation.
 82        cursor.execute(
 83            """
 84            SELECT name, type FROM sqlite_master
 85            WHERE type in ('table', 'view') AND NOT name='sqlite_sequence'
 86            ORDER BY name"""
 87        )
 88        return [TableInfo(row[0], row[1][0]) for row in cursor.fetchall()]
 89
 90    def get_table_description(self, cursor, table_name):
 91        """
 92        Return a description of the table with the DB-API cursor.description
 93        interface.
 94        """
 95        cursor.execute(
 96            "PRAGMA table_info(%s)" % self.connection.ops.quote_name(table_name)
 97        )
 98        table_info = cursor.fetchall()
 99        if not table_info:
100            raise DatabaseError(f"Table {table_name} does not exist (empty pragma).")
101        collations = self._get_column_collations(cursor, table_name)
102        json_columns = set()
103        if self.connection.features.can_introspect_json_field:
104            for line in table_info:
105                column = line[1]
106                json_constraint_sql = '%%json_valid("%s")%%' % column
107                has_json_constraint = cursor.execute(
108                    """
109                    SELECT sql
110                    FROM sqlite_master
111                    WHERE
112                        type = 'table' AND
113                        name = %s AND
114                        sql LIKE %s
115                """,
116                    [table_name, json_constraint_sql],
117                ).fetchone()
118                if has_json_constraint:
119                    json_columns.add(column)
120        return [
121            FieldInfo(
122                name,
123                data_type,
124                get_field_size(data_type),
125                None,
126                None,
127                None,
128                not notnull,
129                default,
130                collations.get(name),
131                pk == 1,
132                name in json_columns,
133            )
134            for cid, name, data_type, notnull, default, pk in table_info
135        ]
136
137    def get_sequences(self, cursor, table_name, table_fields=()):
138        pk_col = self.get_primary_key_column(cursor, table_name)
139        return [{"table": table_name, "column": pk_col}]
140
141    def get_relations(self, cursor, table_name):
142        """
143        Return a dictionary of {column_name: (ref_column_name, ref_table_name)}
144        representing all foreign keys in the given table.
145        """
146        cursor.execute(
147            "PRAGMA foreign_key_list(%s)" % self.connection.ops.quote_name(table_name)
148        )
149        return {
150            column_name: (ref_column_name, ref_table_name)
151            for (
152                _,
153                _,
154                ref_table_name,
155                column_name,
156                ref_column_name,
157                *_,
158            ) in cursor.fetchall()
159        }
160
161    def get_primary_key_columns(self, cursor, table_name):
162        cursor.execute(
163            "PRAGMA table_info(%s)" % self.connection.ops.quote_name(table_name)
164        )
165        return [name for _, name, *_, pk in cursor.fetchall() if pk]
166
167    def _parse_column_or_constraint_definition(self, tokens, columns):
168        token = None
169        is_constraint_definition = None
170        field_name = None
171        constraint_name = None
172        unique = False
173        unique_columns = []
174        check = False
175        check_columns = []
176        braces_deep = 0
177        for token in tokens:
178            if token.match(sqlparse.tokens.Punctuation, "("):
179                braces_deep += 1
180            elif token.match(sqlparse.tokens.Punctuation, ")"):
181                braces_deep -= 1
182                if braces_deep < 0:
183                    # End of columns and constraints for table definition.
184                    break
185            elif braces_deep == 0 and token.match(sqlparse.tokens.Punctuation, ","):
186                # End of current column or constraint definition.
187                break
188            # Detect column or constraint definition by first token.
189            if is_constraint_definition is None:
190                is_constraint_definition = token.match(
191                    sqlparse.tokens.Keyword, "CONSTRAINT"
192                )
193                if is_constraint_definition:
194                    continue
195            if is_constraint_definition:
196                # Detect constraint name by second token.
197                if constraint_name is None:
198                    if token.ttype in (sqlparse.tokens.Name, sqlparse.tokens.Keyword):
199                        constraint_name = token.value
200                    elif token.ttype == sqlparse.tokens.Literal.String.Symbol:
201                        constraint_name = token.value[1:-1]
202                # Start constraint columns parsing after UNIQUE keyword.
203                if token.match(sqlparse.tokens.Keyword, "UNIQUE"):
204                    unique = True
205                    unique_braces_deep = braces_deep
206                elif unique:
207                    if unique_braces_deep == braces_deep:
208                        if unique_columns:
209                            # Stop constraint parsing.
210                            unique = False
211                        continue
212                    if token.ttype in (sqlparse.tokens.Name, sqlparse.tokens.Keyword):
213                        unique_columns.append(token.value)
214                    elif token.ttype == sqlparse.tokens.Literal.String.Symbol:
215                        unique_columns.append(token.value[1:-1])
216            else:
217                # Detect field name by first token.
218                if field_name is None:
219                    if token.ttype in (sqlparse.tokens.Name, sqlparse.tokens.Keyword):
220                        field_name = token.value
221                    elif token.ttype == sqlparse.tokens.Literal.String.Symbol:
222                        field_name = token.value[1:-1]
223                if token.match(sqlparse.tokens.Keyword, "UNIQUE"):
224                    unique_columns = [field_name]
225            # Start constraint columns parsing after CHECK keyword.
226            if token.match(sqlparse.tokens.Keyword, "CHECK"):
227                check = True
228                check_braces_deep = braces_deep
229            elif check:
230                if check_braces_deep == braces_deep:
231                    if check_columns:
232                        # Stop constraint parsing.
233                        check = False
234                    continue
235                if token.ttype in (sqlparse.tokens.Name, sqlparse.tokens.Keyword):
236                    if token.value in columns:
237                        check_columns.append(token.value)
238                elif token.ttype == sqlparse.tokens.Literal.String.Symbol:
239                    if token.value[1:-1] in columns:
240                        check_columns.append(token.value[1:-1])
241        unique_constraint = (
242            {
243                "unique": True,
244                "columns": unique_columns,
245                "primary_key": False,
246                "foreign_key": None,
247                "check": False,
248                "index": False,
249            }
250            if unique_columns
251            else None
252        )
253        check_constraint = (
254            {
255                "check": True,
256                "columns": check_columns,
257                "primary_key": False,
258                "unique": False,
259                "foreign_key": None,
260                "index": False,
261            }
262            if check_columns
263            else None
264        )
265        return constraint_name, unique_constraint, check_constraint, token
266
267    def _parse_table_constraints(self, sql, columns):
268        # Check constraint parsing is based of SQLite syntax diagram.
269        # https://www.sqlite.org/syntaxdiagrams.html#table-constraint
270        statement = sqlparse.parse(sql)[0]
271        constraints = {}
272        unnamed_constrains_index = 0
273        tokens = (token for token in statement.flatten() if not token.is_whitespace)
274        # Go to columns and constraint definition
275        for token in tokens:
276            if token.match(sqlparse.tokens.Punctuation, "("):
277                break
278        # Parse columns and constraint definition
279        while True:
280            (
281                constraint_name,
282                unique,
283                check,
284                end_token,
285            ) = self._parse_column_or_constraint_definition(tokens, columns)
286            if unique:
287                if constraint_name:
288                    constraints[constraint_name] = unique
289                else:
290                    unnamed_constrains_index += 1
291                    constraints[
292                        "__unnamed_constraint_%s__" % unnamed_constrains_index
293                    ] = unique
294            if check:
295                if constraint_name:
296                    constraints[constraint_name] = check
297                else:
298                    unnamed_constrains_index += 1
299                    constraints[
300                        "__unnamed_constraint_%s__" % unnamed_constrains_index
301                    ] = check
302            if end_token.match(sqlparse.tokens.Punctuation, ")"):
303                break
304        return constraints
305
306    def get_constraints(self, cursor, table_name):
307        """
308        Retrieve any constraints or keys (unique, pk, fk, check, index) across
309        one or more columns.
310        """
311        constraints = {}
312        # Find inline check constraints.
313        try:
314            table_schema = cursor.execute(
315                "SELECT sql FROM sqlite_master WHERE type='table' and name={}".format(
316                    self.connection.ops.quote_name(table_name)
317                )
318            ).fetchone()[0]
319        except TypeError:
320            # table_name is a view.
321            pass
322        else:
323            columns = {
324                info.name for info in self.get_table_description(cursor, table_name)
325            }
326            constraints.update(self._parse_table_constraints(table_schema, columns))
327
328        # Get the index info
329        cursor.execute(
330            "PRAGMA index_list(%s)" % self.connection.ops.quote_name(table_name)
331        )
332        for row in cursor.fetchall():
333            # SQLite 3.8.9+ has 5 columns, however older versions only give 3
334            # columns. Discard last 2 columns if there.
335            number, index, unique = row[:3]
336            cursor.execute(
337                "SELECT sql FROM sqlite_master "
338                "WHERE type='index' AND name=%s" % self.connection.ops.quote_name(index)
339            )
340            # There's at most one row.
341            (sql,) = cursor.fetchone() or (None,)
342            # Inline constraints are already detected in
343            # _parse_table_constraints(). The reasons to avoid fetching inline
344            # constraints from `PRAGMA index_list` are:
345            # - Inline constraints can have a different name and information
346            #   than what `PRAGMA index_list` gives.
347            # - Not all inline constraints may appear in `PRAGMA index_list`.
348            if not sql:
349                # An inline constraint
350                continue
351            # Get the index info for that index
352            cursor.execute(
353                "PRAGMA index_info(%s)" % self.connection.ops.quote_name(index)
354            )
355            for index_rank, column_rank, column in cursor.fetchall():
356                if index not in constraints:
357                    constraints[index] = {
358                        "columns": [],
359                        "primary_key": False,
360                        "unique": bool(unique),
361                        "foreign_key": None,
362                        "check": False,
363                        "index": True,
364                    }
365                constraints[index]["columns"].append(column)
366            # Add type and column orders for indexes
367            if constraints[index]["index"]:
368                # SQLite doesn't support any index type other than b-tree
369                constraints[index]["type"] = Index.suffix
370                orders = self._get_index_columns_orders(sql)
371                if orders is not None:
372                    constraints[index]["orders"] = orders
373        # Get the PK
374        pk_columns = self.get_primary_key_columns(cursor, table_name)
375        if pk_columns:
376            # SQLite doesn't actually give a name to the PK constraint,
377            # so we invent one. This is fine, as the SQLite backend never
378            # deletes PK constraints by name, as you can't delete constraints
379            # in SQLite; we remake the table with a new PK instead.
380            constraints["__primary__"] = {
381                "columns": pk_columns,
382                "primary_key": True,
383                "unique": False,  # It's not actually a unique constraint.
384                "foreign_key": None,
385                "check": False,
386                "index": False,
387            }
388        relations = enumerate(self.get_relations(cursor, table_name).items())
389        constraints.update(
390            {
391                f"fk_{index}": {
392                    "columns": [column_name],
393                    "primary_key": False,
394                    "unique": False,
395                    "foreign_key": (ref_table_name, ref_column_name),
396                    "check": False,
397                    "index": False,
398                }
399                for index, (column_name, (ref_column_name, ref_table_name)) in relations
400            }
401        )
402        return constraints
403
404    def _get_index_columns_orders(self, sql):
405        tokens = sqlparse.parse(sql)[0]
406        for token in tokens:
407            if isinstance(token, sqlparse.sql.Parenthesis):
408                columns = str(token).strip("()").split(", ")
409                return ["DESC" if info.endswith("DESC") else "ASC" for info in columns]
410        return None
411
412    def _get_column_collations(self, cursor, table_name):
413        row = cursor.execute(
414            """
415            SELECT sql
416            FROM sqlite_master
417            WHERE type = 'table' AND name = %s
418        """,
419            [table_name],
420        ).fetchone()
421        if not row:
422            return {}
423
424        sql = row[0]
425        columns = str(sqlparse.parse(sql)[0][-1]).strip("()").split(", ")
426        collations = {}
427        for column in columns:
428            tokens = column[1:].split()
429            column_name = tokens[0].strip('"')
430            for index, token in enumerate(tokens):
431                if token == "COLLATE":
432                    collation = tokens[index + 1]
433                    break
434            else:
435                collation = None
436            collations[column_name] = collation
437        return collations