Plain is headed towards 1.0! Subscribe for development updates →

  1from __future__ import annotations
  2
  3from typing import Any, NamedTuple
  4
  5from plain.models.backends.base.introspection import BaseDatabaseIntrospection
  6from plain.models.backends.utils import CursorWrapper
  7from plain.models.indexes import Index
  8
  9
 10class FieldInfo(NamedTuple):
 11    """PostgreSQL-specific FieldInfo extending base with autofield and comment."""
 12
 13    # Fields from BaseFieldInfo
 14    name: str
 15    type_code: Any
 16    display_size: int | None
 17    internal_size: int | None
 18    precision: int | None
 19    scale: int | None
 20    null_ok: bool | None
 21    default: Any
 22    collation: str | None
 23    # PostgreSQL-specific extensions
 24    is_autofield: bool
 25    comment: str | None
 26
 27
 28class TableInfo(NamedTuple):
 29    """PostgreSQL-specific TableInfo extending base with comment support."""
 30
 31    # Fields from BaseTableInfo
 32    name: str
 33    type: str
 34    # PostgreSQL-specific extension
 35    comment: str | None
 36
 37
 38class DatabaseIntrospection(BaseDatabaseIntrospection):
 39    # Maps type codes to Plain Field types.
 40    data_types_reverse = {
 41        16: "BooleanField",
 42        17: "BinaryField",
 43        20: "BigIntegerField",
 44        21: "SmallIntegerField",
 45        23: "IntegerField",
 46        25: "TextField",
 47        700: "FloatField",
 48        701: "FloatField",
 49        869: "GenericIPAddressField",
 50        1042: "CharField",  # blank-padded
 51        1043: "CharField",
 52        1082: "DateField",
 53        1083: "TimeField",
 54        1114: "DateTimeField",
 55        1184: "DateTimeField",
 56        1186: "DurationField",
 57        1266: "TimeField",
 58        1700: "DecimalField",
 59        2950: "UUIDField",
 60        3802: "JSONField",
 61    }
 62    # A hook for subclasses.
 63    index_default_access_method = "btree"
 64
 65    ignored_tables: list[str] = []
 66
 67    def get_field_type(self, data_type: Any, description: Any) -> str:
 68        field_type = super().get_field_type(data_type, description)
 69        if description.is_autofield or (
 70            # Required for pre-Plain 4.1 serial columns.
 71            description.default and "nextval" in description.default
 72        ):
 73            if field_type == "BigIntegerField":
 74                return "PrimaryKeyField"
 75        return field_type
 76
 77    def get_table_list(self, cursor: CursorWrapper) -> list[TableInfo]:
 78        """Return a list of table and view names in the current database."""
 79        cursor.execute(
 80            """
 81            SELECT
 82                c.relname,
 83                CASE
 84                    WHEN c.relispartition THEN 'p'
 85                    WHEN c.relkind IN ('m', 'v') THEN 'v'
 86                    ELSE 't'
 87                END,
 88                obj_description(c.oid, 'pg_class')
 89            FROM pg_catalog.pg_class c
 90            LEFT JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace
 91            WHERE c.relkind IN ('f', 'm', 'p', 'r', 'v')
 92                AND n.nspname NOT IN ('pg_catalog', 'pg_toast')
 93                AND pg_catalog.pg_table_is_visible(c.oid)
 94        """
 95        )
 96        return [
 97            TableInfo(*row)
 98            for row in cursor.fetchall()
 99            if row[0] not in self.ignored_tables
100        ]
101
102    def get_table_description(
103        self, cursor: CursorWrapper, table_name: str
104    ) -> list[FieldInfo]:
105        """
106        Return a description of the table with the DB-API cursor.description
107        interface.
108        """
109        # Query the pg_catalog tables as cursor.description does not reliably
110        # return the nullable property and information_schema.columns does not
111        # contain details of materialized views.
112        cursor.execute(
113            """
114            SELECT
115                a.attname AS column_name,
116                NOT (a.attnotnull OR (t.typtype = 'd' AND t.typnotnull)) AS is_nullable,
117                pg_get_expr(ad.adbin, ad.adrelid) AS column_default,
118                CASE WHEN collname = 'default' THEN NULL ELSE collname END AS collation,
119                a.attidentity != '' AS is_autofield,
120                col_description(a.attrelid, a.attnum) AS column_comment
121            FROM pg_attribute a
122            LEFT JOIN pg_attrdef ad ON a.attrelid = ad.adrelid AND a.attnum = ad.adnum
123            LEFT JOIN pg_collation co ON a.attcollation = co.oid
124            JOIN pg_type t ON a.atttypid = t.oid
125            JOIN pg_class c ON a.attrelid = c.oid
126            JOIN pg_namespace n ON c.relnamespace = n.oid
127            WHERE c.relkind IN ('f', 'm', 'p', 'r', 'v')
128                AND c.relname = %s
129                AND n.nspname NOT IN ('pg_catalog', 'pg_toast')
130                AND pg_catalog.pg_table_is_visible(c.oid)
131        """,
132            [table_name],
133        )
134        field_map = {line[0]: line[1:] for line in cursor.fetchall()}
135        cursor.execute(
136            f"SELECT * FROM {self.connection.ops.quote_name(table_name)} LIMIT 1"
137        )
138        return [
139            FieldInfo(
140                line.name,
141                line.type_code,
142                line.internal_size if line.display_size is None else line.display_size,
143                line.internal_size,
144                line.precision,
145                line.scale,
146                *field_map[line.name],
147            )
148            for line in cursor.description
149        ]
150
151    def get_sequences(
152        self, cursor: CursorWrapper, table_name: str, table_fields: tuple[Any, ...] = ()
153    ) -> list[dict[str, Any]]:
154        cursor.execute(
155            """
156            SELECT
157                s.relname AS sequence_name,
158                a.attname AS colname
159            FROM
160                pg_class s
161                JOIN pg_depend d ON d.objid = s.oid
162                    AND d.classid = 'pg_class'::regclass
163                    AND d.refclassid = 'pg_class'::regclass
164                JOIN pg_attribute a ON d.refobjid = a.attrelid
165                    AND d.refobjsubid = a.attnum
166                JOIN pg_class tbl ON tbl.oid = d.refobjid
167                    AND tbl.relname = %s
168                    AND pg_catalog.pg_table_is_visible(tbl.oid)
169            WHERE
170                s.relkind = 'S';
171        """,
172            [table_name],
173        )
174        return [
175            {"name": row[0], "table": table_name, "column": row[1]}
176            for row in cursor.fetchall()
177        ]
178
179    def get_relations(
180        self, cursor: CursorWrapper, table_name: str
181    ) -> dict[str, tuple[str, str]]:
182        """
183        Return a dictionary of {field_name: (field_name_other_table, other_table)}
184        representing all foreign keys in the given table.
185        """
186        cursor.execute(
187            """
188            SELECT a1.attname, c2.relname, a2.attname
189            FROM pg_constraint con
190            LEFT JOIN pg_class c1 ON con.conrelid = c1.oid
191            LEFT JOIN pg_class c2 ON con.confrelid = c2.oid
192            LEFT JOIN
193                pg_attribute a1 ON c1.oid = a1.attrelid AND a1.attnum = con.conkey[1]
194            LEFT JOIN
195                pg_attribute a2 ON c2.oid = a2.attrelid AND a2.attnum = con.confkey[1]
196            WHERE
197                c1.relname = %s AND
198                con.contype = 'f' AND
199                c1.relnamespace = c2.relnamespace AND
200                pg_catalog.pg_table_is_visible(c1.oid)
201        """,
202            [table_name],
203        )
204        return {row[0]: (row[2], row[1]) for row in cursor.fetchall()}
205
206    def get_constraints(
207        self, cursor: CursorWrapper, table_name: str
208    ) -> dict[str, dict[str, Any]]:
209        """
210        Retrieve any constraints or keys (unique, pk, fk, check, index) across
211        one or more columns. Also retrieve the definition of expression-based
212        indexes.
213        """
214        constraints: dict[str, dict[str, Any]] = {}
215        # Loop over the key table, collecting things as constraints. The column
216        # array must return column names in the same order in which they were
217        # created.
218        cursor.execute(
219            """
220            SELECT
221                c.conname,
222                array(
223                    SELECT attname
224                    FROM unnest(c.conkey) WITH ORDINALITY cols(colid, arridx)
225                    JOIN pg_attribute AS ca ON cols.colid = ca.attnum
226                    WHERE ca.attrelid = c.conrelid
227                    ORDER BY cols.arridx
228                ),
229                c.contype,
230                (SELECT fkc.relname || '.' || fka.attname
231                FROM pg_attribute AS fka
232                JOIN pg_class AS fkc ON fka.attrelid = fkc.oid
233                WHERE fka.attrelid = c.confrelid AND fka.attnum = c.confkey[1]),
234                cl.reloptions
235            FROM pg_constraint AS c
236            JOIN pg_class AS cl ON c.conrelid = cl.oid
237            WHERE cl.relname = %s AND pg_catalog.pg_table_is_visible(cl.oid)
238        """,
239            [table_name],
240        )
241        for constraint, columns, kind, used_cols, options in cursor.fetchall():
242            constraints[constraint] = {
243                "columns": columns,
244                "primary_key": kind == "p",
245                "unique": kind in ["p", "u"],
246                "foreign_key": tuple(used_cols.split(".", 1)) if kind == "f" else None,
247                "check": kind == "c",
248                "index": False,
249                "definition": None,
250                "options": options,
251            }
252        # Now get indexes
253        cursor.execute(
254            """
255            SELECT
256                indexname,
257                array_agg(attname ORDER BY arridx),
258                indisunique,
259                indisprimary,
260                array_agg(ordering ORDER BY arridx),
261                amname,
262                exprdef,
263                s2.attoptions
264            FROM (
265                SELECT
266                    c2.relname as indexname, idx.*, attr.attname, am.amname,
267                    CASE
268                        WHEN idx.indexprs IS NOT NULL THEN
269                            pg_get_indexdef(idx.indexrelid)
270                    END AS exprdef,
271                    CASE am.amname
272                        WHEN %s THEN
273                            CASE (option & 1)
274                                WHEN 1 THEN 'DESC' ELSE 'ASC'
275                            END
276                    END as ordering,
277                    c2.reloptions as attoptions
278                FROM (
279                    SELECT *
280                    FROM
281                        pg_index i,
282                        unnest(i.indkey, i.indoption)
283                            WITH ORDINALITY koi(key, option, arridx)
284                ) idx
285                LEFT JOIN pg_class c ON idx.indrelid = c.oid
286                LEFT JOIN pg_class c2 ON idx.indexrelid = c2.oid
287                LEFT JOIN pg_am am ON c2.relam = am.oid
288                LEFT JOIN
289                    pg_attribute attr ON attr.attrelid = c.oid AND attr.attnum = idx.key
290                WHERE c.relname = %s AND pg_catalog.pg_table_is_visible(c.oid)
291            ) s2
292            GROUP BY indexname, indisunique, indisprimary, amname, exprdef, attoptions;
293        """,
294            [self.index_default_access_method, table_name],
295        )
296        for (
297            index,
298            columns,
299            unique,
300            primary,
301            orders,
302            type_,
303            definition,
304            options,
305        ) in cursor.fetchall():
306            if index not in constraints:
307                basic_index = (
308                    type_ == self.index_default_access_method and options is None
309                )
310                constraints[index] = {
311                    "columns": columns if columns != [None] else [],
312                    "orders": orders if orders != [None] else [],
313                    "primary_key": primary,
314                    "unique": unique,
315                    "foreign_key": None,
316                    "check": False,
317                    "index": True,
318                    "type": Index.suffix if basic_index else type_,
319                    "definition": definition,
320                    "options": options,
321                }
322        return constraints