1from __future__ import annotations
  2
  3from collections.abc import Sequence
  4from typing import Any, NamedTuple
  5
  6from plain.models.backends.base.introspection import BaseDatabaseIntrospection
  7from plain.models.backends.utils import CursorWrapper
  8from plain.models.indexes import Index
  9
 10
 11class FieldInfo(NamedTuple):
 12    """PostgreSQL-specific FieldInfo extending base with autofield and comment."""
 13
 14    # Fields from BaseFieldInfo
 15    name: str
 16    type_code: Any
 17    display_size: int | None
 18    internal_size: int | None
 19    precision: int | None
 20    scale: int | None
 21    null_ok: bool | None
 22    default: Any
 23    # PostgreSQL-specific extensions
 24    is_autofield: bool
 25    comment: str | None
 26
 27
 28class TableInfo(NamedTuple):
 29    """PostgreSQL-specific TableInfo extending base with comment support."""
 30
 31    # Fields from BaseTableInfo
 32    name: str
 33    type: str
 34    # PostgreSQL-specific extension
 35    comment: str | None
 36
 37
 38class DatabaseIntrospection(BaseDatabaseIntrospection):
 39    # Maps type codes to Plain Field types.
 40    data_types_reverse = {
 41        16: "BooleanField",
 42        17: "BinaryField",
 43        20: "BigIntegerField",
 44        21: "SmallIntegerField",
 45        23: "IntegerField",
 46        25: "TextField",
 47        700: "FloatField",
 48        701: "FloatField",
 49        869: "GenericIPAddressField",
 50        1042: "CharField",  # blank-padded
 51        1043: "CharField",
 52        1082: "DateField",
 53        1083: "TimeField",
 54        1114: "DateTimeField",
 55        1184: "DateTimeField",
 56        1186: "DurationField",
 57        1266: "TimeField",
 58        1700: "DecimalField",
 59        2950: "UUIDField",
 60        3802: "JSONField",
 61    }
 62    # A hook for subclasses.
 63    index_default_access_method = "btree"
 64
 65    ignored_tables: list[str] = []
 66
 67    def get_field_type(self, data_type: Any, description: Any) -> str:
 68        field_type = super().get_field_type(data_type, description)
 69        if description.is_autofield or (
 70            # Required for pre-Plain 4.1 serial columns.
 71            description.default and "nextval" in description.default
 72        ):
 73            if field_type == "BigIntegerField":
 74                return "PrimaryKeyField"
 75        return field_type
 76
 77    def get_table_list(self, cursor: CursorWrapper) -> Sequence[TableInfo]:
 78        """Return a list of table and view names in the current database."""
 79        cursor.execute(
 80            """
 81            SELECT
 82                c.relname,
 83                CASE
 84                    WHEN c.relispartition THEN 'p'
 85                    WHEN c.relkind IN ('m', 'v') THEN 'v'
 86                    ELSE 't'
 87                END,
 88                obj_description(c.oid, 'pg_class')
 89            FROM pg_catalog.pg_class c
 90            LEFT JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace
 91            WHERE c.relkind IN ('f', 'm', 'p', 'r', 'v')
 92                AND n.nspname NOT IN ('pg_catalog', 'pg_toast')
 93                AND pg_catalog.pg_table_is_visible(c.oid)
 94        """
 95        )
 96        return [
 97            TableInfo(*row)
 98            for row in cursor.fetchall()
 99            if row[0] not in self.ignored_tables
100        ]
101
102    def get_table_description(
103        self, cursor: CursorWrapper, table_name: str
104    ) -> Sequence[FieldInfo]:
105        """
106        Return a description of the table with the DB-API cursor.description
107        interface.
108        """
109        # Query the pg_catalog tables as cursor.description does not reliably
110        # return the nullable property and information_schema.columns does not
111        # contain details of materialized views.
112        cursor.execute(
113            """
114            SELECT
115                a.attname AS column_name,
116                NOT (a.attnotnull OR (t.typtype = 'd' AND t.typnotnull)) AS is_nullable,
117                pg_get_expr(ad.adbin, ad.adrelid) AS column_default,
118                a.attidentity != '' AS is_autofield,
119                col_description(a.attrelid, a.attnum) AS column_comment
120            FROM pg_attribute a
121            LEFT JOIN pg_attrdef ad ON a.attrelid = ad.adrelid AND a.attnum = ad.adnum
122            JOIN pg_type t ON a.atttypid = t.oid
123            JOIN pg_class c ON a.attrelid = c.oid
124            JOIN pg_namespace n ON c.relnamespace = n.oid
125            WHERE c.relkind IN ('f', 'm', 'p', 'r', 'v')
126                AND c.relname = %s
127                AND n.nspname NOT IN ('pg_catalog', 'pg_toast')
128                AND pg_catalog.pg_table_is_visible(c.oid)
129        """,
130            [table_name],
131        )
132        field_map = {line[0]: line[1:] for line in cursor.fetchall()}
133        cursor.execute(
134            f"SELECT * FROM {self.connection.ops.quote_name(table_name)} LIMIT 1"
135        )
136        return [
137            FieldInfo(
138                line.name,
139                line.type_code,
140                line.internal_size if line.display_size is None else line.display_size,
141                line.internal_size,
142                line.precision,
143                line.scale,
144                *field_map[line.name],
145            )
146            for line in cursor.description
147        ]
148
149    def get_sequences(
150        self, cursor: CursorWrapper, table_name: str, table_fields: tuple[Any, ...] = ()
151    ) -> list[dict[str, Any]]:
152        cursor.execute(
153            """
154            SELECT
155                s.relname AS sequence_name,
156                a.attname AS colname
157            FROM
158                pg_class s
159                JOIN pg_depend d ON d.objid = s.oid
160                    AND d.classid = 'pg_class'::regclass
161                    AND d.refclassid = 'pg_class'::regclass
162                JOIN pg_attribute a ON d.refobjid = a.attrelid
163                    AND d.refobjsubid = a.attnum
164                JOIN pg_class tbl ON tbl.oid = d.refobjid
165                    AND tbl.relname = %s
166                    AND pg_catalog.pg_table_is_visible(tbl.oid)
167            WHERE
168                s.relkind = 'S';
169        """,
170            [table_name],
171        )
172        return [
173            {"name": row[0], "table": table_name, "column": row[1]}
174            for row in cursor.fetchall()
175        ]
176
177    def get_relations(
178        self, cursor: CursorWrapper, table_name: str
179    ) -> dict[str, tuple[str, str]]:
180        """
181        Return a dictionary of {field_name: (field_name_other_table, other_table)}
182        representing all foreign keys in the given table.
183        """
184        cursor.execute(
185            """
186            SELECT a1.attname, c2.relname, a2.attname
187            FROM pg_constraint con
188            LEFT JOIN pg_class c1 ON con.conrelid = c1.oid
189            LEFT JOIN pg_class c2 ON con.confrelid = c2.oid
190            LEFT JOIN
191                pg_attribute a1 ON c1.oid = a1.attrelid AND a1.attnum = con.conkey[1]
192            LEFT JOIN
193                pg_attribute a2 ON c2.oid = a2.attrelid AND a2.attnum = con.confkey[1]
194            WHERE
195                c1.relname = %s AND
196                con.contype = 'f' AND
197                c1.relnamespace = c2.relnamespace AND
198                pg_catalog.pg_table_is_visible(c1.oid)
199        """,
200            [table_name],
201        )
202        return {row[0]: (row[2], row[1]) for row in cursor.fetchall()}
203
204    def get_constraints(
205        self, cursor: CursorWrapper, table_name: str
206    ) -> dict[str, dict[str, Any]]:
207        """
208        Retrieve any constraints or keys (unique, pk, fk, check, index) across
209        one or more columns. Also retrieve the definition of expression-based
210        indexes.
211        """
212        constraints: dict[str, dict[str, Any]] = {}
213        # Loop over the key table, collecting things as constraints. The column
214        # array must return column names in the same order in which they were
215        # created.
216        cursor.execute(
217            """
218            SELECT
219                c.conname,
220                array(
221                    SELECT attname
222                    FROM unnest(c.conkey) WITH ORDINALITY cols(colid, arridx)
223                    JOIN pg_attribute AS ca ON cols.colid = ca.attnum
224                    WHERE ca.attrelid = c.conrelid
225                    ORDER BY cols.arridx
226                ),
227                c.contype,
228                (SELECT fkc.relname || '.' || fka.attname
229                FROM pg_attribute AS fka
230                JOIN pg_class AS fkc ON fka.attrelid = fkc.oid
231                WHERE fka.attrelid = c.confrelid AND fka.attnum = c.confkey[1]),
232                cl.reloptions
233            FROM pg_constraint AS c
234            JOIN pg_class AS cl ON c.conrelid = cl.oid
235            WHERE cl.relname = %s AND pg_catalog.pg_table_is_visible(cl.oid)
236        """,
237            [table_name],
238        )
239        for constraint, columns, kind, used_cols, options in cursor.fetchall():
240            constraints[constraint] = {
241                "columns": columns,
242                "primary_key": kind == "p",
243                "unique": kind in ["p", "u"],
244                "foreign_key": tuple(used_cols.split(".", 1)) if kind == "f" else None,
245                "check": kind == "c",
246                "index": False,
247                "definition": None,
248                "options": options,
249            }
250        # Now get indexes
251        cursor.execute(
252            """
253            SELECT
254                indexname,
255                array_agg(attname ORDER BY arridx),
256                indisunique,
257                indisprimary,
258                array_agg(ordering ORDER BY arridx),
259                amname,
260                exprdef,
261                s2.attoptions
262            FROM (
263                SELECT
264                    c2.relname as indexname, idx.*, attr.attname, am.amname,
265                    CASE
266                        WHEN idx.indexprs IS NOT NULL THEN
267                            pg_get_indexdef(idx.indexrelid)
268                    END AS exprdef,
269                    CASE am.amname
270                        WHEN %s THEN
271                            CASE (option & 1)
272                                WHEN 1 THEN 'DESC' ELSE 'ASC'
273                            END
274                    END as ordering,
275                    c2.reloptions as attoptions
276                FROM (
277                    SELECT *
278                    FROM
279                        pg_index i,
280                        unnest(i.indkey, i.indoption)
281                            WITH ORDINALITY koi(key, option, arridx)
282                ) idx
283                LEFT JOIN pg_class c ON idx.indrelid = c.oid
284                LEFT JOIN pg_class c2 ON idx.indexrelid = c2.oid
285                LEFT JOIN pg_am am ON c2.relam = am.oid
286                LEFT JOIN
287                    pg_attribute attr ON attr.attrelid = c.oid AND attr.attnum = idx.key
288                WHERE c.relname = %s AND pg_catalog.pg_table_is_visible(c.oid)
289            ) s2
290            GROUP BY indexname, indisunique, indisprimary, amname, exprdef, attoptions;
291        """,
292            [self.index_default_access_method, table_name],
293        )
294        for (
295            index,
296            columns,
297            unique,
298            primary,
299            orders,
300            type_,
301            definition,
302            options,
303        ) in cursor.fetchall():
304            if index not in constraints:
305                basic_index = (
306                    type_ == self.index_default_access_method and options is None
307                )
308                constraints[index] = {
309                    "columns": columns if columns != [None] else [],
310                    "orders": orders if orders != [None] else [],
311                    "primary_key": primary,
312                    "unique": unique,
313                    "foreign_key": None,
314                    "check": False,
315                    "index": True,
316                    "type": Index.suffix if basic_index else type_,
317                    "definition": definition,
318                    "options": options,
319                }
320        return constraints