Plain is headed towards 1.0! Subscribe for development updates →

  1from __future__ import annotations
  2
  3from collections.abc import Sequence
  4from typing import Any, NamedTuple
  5
  6from plain.models.backends.base.introspection import BaseDatabaseIntrospection
  7from plain.models.backends.utils import CursorWrapper
  8from plain.models.indexes import Index
  9
 10
class FieldInfo(NamedTuple):
    """
    PostgreSQL-specific FieldInfo extending the base fields with
    ``is_autofield`` and ``comment``.

    Field order matters: DatabaseIntrospection.get_table_description()
    constructs instances positionally, taking the first six values from
    ``cursor.description`` and the remaining five from a pg_catalog query.
    """

    # Fields from BaseFieldInfo
    name: str
    type_code: Any
    display_size: int | None
    internal_size: int | None
    precision: int | None
    scale: int | None
    # Whether the column accepts NULL (accounts for NOT NULL domains).
    null_ok: bool | None
    # Column default expression text, or None when there is no default.
    default: Any
    # Collation name; None when the column uses the "default" collation.
    collation: str | None
    # PostgreSQL-specific extensions
    # True for identity columns (pg_attribute.attidentity is non-empty).
    is_autofield: bool
    # Column comment from col_description(), or None.
    comment: str | None
 27
 28
class TableInfo(NamedTuple):
    """
    PostgreSQL-specific TableInfo extending the base fields with ``comment``.

    Field order matters: DatabaseIntrospection.get_table_list() constructs
    instances positionally from each result row.
    """

    # Fields from BaseTableInfo
    name: str
    # One-letter kind as produced by get_table_list(): 'p' for a partition,
    # 'v' for a view or materialized view, 't' for everything else.
    type: str
    # PostgreSQL-specific extension
    # Table comment from obj_description(), or None.
    comment: str | None
 37
 38
 39class DatabaseIntrospection(BaseDatabaseIntrospection):
 40    # Maps type codes to Plain Field types.
 41    data_types_reverse = {
 42        16: "BooleanField",
 43        17: "BinaryField",
 44        20: "BigIntegerField",
 45        21: "SmallIntegerField",
 46        23: "IntegerField",
 47        25: "TextField",
 48        700: "FloatField",
 49        701: "FloatField",
 50        869: "GenericIPAddressField",
 51        1042: "CharField",  # blank-padded
 52        1043: "CharField",
 53        1082: "DateField",
 54        1083: "TimeField",
 55        1114: "DateTimeField",
 56        1184: "DateTimeField",
 57        1186: "DurationField",
 58        1266: "TimeField",
 59        1700: "DecimalField",
 60        2950: "UUIDField",
 61        3802: "JSONField",
 62    }
 63    # A hook for subclasses.
 64    index_default_access_method = "btree"
 65
 66    ignored_tables: list[str] = []
 67
 68    def get_field_type(self, data_type: Any, description: Any) -> str:
 69        field_type = super().get_field_type(data_type, description)
 70        if description.is_autofield or (
 71            # Required for pre-Plain 4.1 serial columns.
 72            description.default and "nextval" in description.default
 73        ):
 74            if field_type == "BigIntegerField":
 75                return "PrimaryKeyField"
 76        return field_type
 77
 78    def get_table_list(self, cursor: CursorWrapper) -> Sequence[TableInfo]:
 79        """Return a list of table and view names in the current database."""
 80        cursor.execute(
 81            """
 82            SELECT
 83                c.relname,
 84                CASE
 85                    WHEN c.relispartition THEN 'p'
 86                    WHEN c.relkind IN ('m', 'v') THEN 'v'
 87                    ELSE 't'
 88                END,
 89                obj_description(c.oid, 'pg_class')
 90            FROM pg_catalog.pg_class c
 91            LEFT JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace
 92            WHERE c.relkind IN ('f', 'm', 'p', 'r', 'v')
 93                AND n.nspname NOT IN ('pg_catalog', 'pg_toast')
 94                AND pg_catalog.pg_table_is_visible(c.oid)
 95        """
 96        )
 97        return [
 98            TableInfo(*row)
 99            for row in cursor.fetchall()
100            if row[0] not in self.ignored_tables
101        ]
102
103    def get_table_description(
104        self, cursor: CursorWrapper, table_name: str
105    ) -> Sequence[FieldInfo]:
106        """
107        Return a description of the table with the DB-API cursor.description
108        interface.
109        """
110        # Query the pg_catalog tables as cursor.description does not reliably
111        # return the nullable property and information_schema.columns does not
112        # contain details of materialized views.
113        cursor.execute(
114            """
115            SELECT
116                a.attname AS column_name,
117                NOT (a.attnotnull OR (t.typtype = 'd' AND t.typnotnull)) AS is_nullable,
118                pg_get_expr(ad.adbin, ad.adrelid) AS column_default,
119                CASE WHEN collname = 'default' THEN NULL ELSE collname END AS collation,
120                a.attidentity != '' AS is_autofield,
121                col_description(a.attrelid, a.attnum) AS column_comment
122            FROM pg_attribute a
123            LEFT JOIN pg_attrdef ad ON a.attrelid = ad.adrelid AND a.attnum = ad.adnum
124            LEFT JOIN pg_collation co ON a.attcollation = co.oid
125            JOIN pg_type t ON a.atttypid = t.oid
126            JOIN pg_class c ON a.attrelid = c.oid
127            JOIN pg_namespace n ON c.relnamespace = n.oid
128            WHERE c.relkind IN ('f', 'm', 'p', 'r', 'v')
129                AND c.relname = %s
130                AND n.nspname NOT IN ('pg_catalog', 'pg_toast')
131                AND pg_catalog.pg_table_is_visible(c.oid)
132        """,
133            [table_name],
134        )
135        field_map = {line[0]: line[1:] for line in cursor.fetchall()}
136        cursor.execute(
137            f"SELECT * FROM {self.connection.ops.quote_name(table_name)} LIMIT 1"
138        )
139        return [
140            FieldInfo(
141                line.name,
142                line.type_code,
143                line.internal_size if line.display_size is None else line.display_size,
144                line.internal_size,
145                line.precision,
146                line.scale,
147                *field_map[line.name],
148            )
149            for line in cursor.description
150        ]
151
152    def get_sequences(
153        self, cursor: CursorWrapper, table_name: str, table_fields: tuple[Any, ...] = ()
154    ) -> list[dict[str, Any]]:
155        cursor.execute(
156            """
157            SELECT
158                s.relname AS sequence_name,
159                a.attname AS colname
160            FROM
161                pg_class s
162                JOIN pg_depend d ON d.objid = s.oid
163                    AND d.classid = 'pg_class'::regclass
164                    AND d.refclassid = 'pg_class'::regclass
165                JOIN pg_attribute a ON d.refobjid = a.attrelid
166                    AND d.refobjsubid = a.attnum
167                JOIN pg_class tbl ON tbl.oid = d.refobjid
168                    AND tbl.relname = %s
169                    AND pg_catalog.pg_table_is_visible(tbl.oid)
170            WHERE
171                s.relkind = 'S';
172        """,
173            [table_name],
174        )
175        return [
176            {"name": row[0], "table": table_name, "column": row[1]}
177            for row in cursor.fetchall()
178        ]
179
180    def get_relations(
181        self, cursor: CursorWrapper, table_name: str
182    ) -> dict[str, tuple[str, str]]:
183        """
184        Return a dictionary of {field_name: (field_name_other_table, other_table)}
185        representing all foreign keys in the given table.
186        """
187        cursor.execute(
188            """
189            SELECT a1.attname, c2.relname, a2.attname
190            FROM pg_constraint con
191            LEFT JOIN pg_class c1 ON con.conrelid = c1.oid
192            LEFT JOIN pg_class c2 ON con.confrelid = c2.oid
193            LEFT JOIN
194                pg_attribute a1 ON c1.oid = a1.attrelid AND a1.attnum = con.conkey[1]
195            LEFT JOIN
196                pg_attribute a2 ON c2.oid = a2.attrelid AND a2.attnum = con.confkey[1]
197            WHERE
198                c1.relname = %s AND
199                con.contype = 'f' AND
200                c1.relnamespace = c2.relnamespace AND
201                pg_catalog.pg_table_is_visible(c1.oid)
202        """,
203            [table_name],
204        )
205        return {row[0]: (row[2], row[1]) for row in cursor.fetchall()}
206
207    def get_constraints(
208        self, cursor: CursorWrapper, table_name: str
209    ) -> dict[str, dict[str, Any]]:
210        """
211        Retrieve any constraints or keys (unique, pk, fk, check, index) across
212        one or more columns. Also retrieve the definition of expression-based
213        indexes.
214        """
215        constraints: dict[str, dict[str, Any]] = {}
216        # Loop over the key table, collecting things as constraints. The column
217        # array must return column names in the same order in which they were
218        # created.
219        cursor.execute(
220            """
221            SELECT
222                c.conname,
223                array(
224                    SELECT attname
225                    FROM unnest(c.conkey) WITH ORDINALITY cols(colid, arridx)
226                    JOIN pg_attribute AS ca ON cols.colid = ca.attnum
227                    WHERE ca.attrelid = c.conrelid
228                    ORDER BY cols.arridx
229                ),
230                c.contype,
231                (SELECT fkc.relname || '.' || fka.attname
232                FROM pg_attribute AS fka
233                JOIN pg_class AS fkc ON fka.attrelid = fkc.oid
234                WHERE fka.attrelid = c.confrelid AND fka.attnum = c.confkey[1]),
235                cl.reloptions
236            FROM pg_constraint AS c
237            JOIN pg_class AS cl ON c.conrelid = cl.oid
238            WHERE cl.relname = %s AND pg_catalog.pg_table_is_visible(cl.oid)
239        """,
240            [table_name],
241        )
242        for constraint, columns, kind, used_cols, options in cursor.fetchall():
243            constraints[constraint] = {
244                "columns": columns,
245                "primary_key": kind == "p",
246                "unique": kind in ["p", "u"],
247                "foreign_key": tuple(used_cols.split(".", 1)) if kind == "f" else None,
248                "check": kind == "c",
249                "index": False,
250                "definition": None,
251                "options": options,
252            }
253        # Now get indexes
254        cursor.execute(
255            """
256            SELECT
257                indexname,
258                array_agg(attname ORDER BY arridx),
259                indisunique,
260                indisprimary,
261                array_agg(ordering ORDER BY arridx),
262                amname,
263                exprdef,
264                s2.attoptions
265            FROM (
266                SELECT
267                    c2.relname as indexname, idx.*, attr.attname, am.amname,
268                    CASE
269                        WHEN idx.indexprs IS NOT NULL THEN
270                            pg_get_indexdef(idx.indexrelid)
271                    END AS exprdef,
272                    CASE am.amname
273                        WHEN %s THEN
274                            CASE (option & 1)
275                                WHEN 1 THEN 'DESC' ELSE 'ASC'
276                            END
277                    END as ordering,
278                    c2.reloptions as attoptions
279                FROM (
280                    SELECT *
281                    FROM
282                        pg_index i,
283                        unnest(i.indkey, i.indoption)
284                            WITH ORDINALITY koi(key, option, arridx)
285                ) idx
286                LEFT JOIN pg_class c ON idx.indrelid = c.oid
287                LEFT JOIN pg_class c2 ON idx.indexrelid = c2.oid
288                LEFT JOIN pg_am am ON c2.relam = am.oid
289                LEFT JOIN
290                    pg_attribute attr ON attr.attrelid = c.oid AND attr.attnum = idx.key
291                WHERE c.relname = %s AND pg_catalog.pg_table_is_visible(c.oid)
292            ) s2
293            GROUP BY indexname, indisunique, indisprimary, amname, exprdef, attoptions;
294        """,
295            [self.index_default_access_method, table_name],
296        )
297        for (
298            index,
299            columns,
300            unique,
301            primary,
302            orders,
303            type_,
304            definition,
305            options,
306        ) in cursor.fetchall():
307            if index not in constraints:
308                basic_index = (
309                    type_ == self.index_default_access_method and options is None
310                )
311                constraints[index] = {
312                    "columns": columns if columns != [None] else [],
313                    "orders": orders if orders != [None] else [],
314                    "primary_key": primary,
315                    "unique": unique,
316                    "foreign_key": None,
317                    "check": False,
318                    "index": True,
319                    "type": Index.suffix if basic_index else type_,
320                    "definition": definition,
321                    "options": options,
322                }
323        return constraints