1from __future__ import annotations
2
3from collections.abc import Sequence
4from typing import Any, NamedTuple
5
6from plain.models.backends.base.introspection import BaseDatabaseIntrospection
7from plain.models.backends.utils import CursorWrapper
8from plain.models.indexes import Index
9
10
class FieldInfo(NamedTuple):
    """Column description record produced by PostgreSQL table introspection.

    Holds the base field-info columns followed by two PostgreSQL additions
    (``is_autofield`` and ``comment``). ``get_table_description()`` builds
    instances positionally, so the order of the fields below is part of the
    contract and must not change.
    """

    # --- Fields from BaseFieldInfo -----------------------------------------
    # Column name as reported by the cursor description.
    name: str
    # Type code reported by the cursor description for the column.
    type_code: Any
    # Display size; get_table_description() substitutes internal_size when
    # the driver reports no display size.
    display_size: int | None
    # Internal size reported by the cursor description.
    internal_size: int | None
    # Precision reported by the cursor description.
    precision: int | None
    # Scale reported by the cursor description.
    scale: int | None
    # Whether the column accepts NULL (the catalog query's is_nullable).
    null_ok: bool | None
    # Column default expression (the catalog query's column_default), if any.
    default: Any

    # --- PostgreSQL-specific extensions ------------------------------------
    # True when the catalog reports a non-empty attidentity for the column.
    is_autofield: bool
    # Column comment from col_description(), or None when there is none.
    comment: str | None
26
27
class TableInfo(NamedTuple):
    """Table description record produced by PostgreSQL table listing.

    Holds the base table-info columns followed by a PostgreSQL-specific
    ``comment``. ``get_table_list()`` builds instances with
    ``TableInfo(*row)``, so the field order must match the column order of
    its query.
    """

    # --- Fields from BaseTableInfo -----------------------------------------
    # Relation name (pg_class.relname).
    name: str
    # One-letter kind computed by the listing query: 'p' for a partition,
    # 'v' for relkind 'm' or 'v', and 't' for everything else.
    type: str

    # --- PostgreSQL-specific extension -------------------------------------
    # Relation comment from obj_description(), or None when there is none.
    comment: str | None
36
37
class DatabaseIntrospection(BaseDatabaseIntrospection):
    """Introspect tables, columns, sequences, relations and constraints.

    Every lookup is answered by querying the ``pg_catalog`` system tables
    through the cursor the caller passes in.
    """

    # Reverse lookup from a column's type code to the name of the Plain field
    # class representing it. The keys are PostgreSQL built-in type OIDs; the
    # trailing comments give the corresponding pg_type names.
    data_types_reverse = {
        16: "BooleanField",  # bool
        17: "BinaryField",  # bytea
        20: "BigIntegerField",  # int8
        21: "SmallIntegerField",  # int2
        23: "IntegerField",  # int4
        25: "TextField",  # text
        700: "FloatField",  # float4
        701: "FloatField",  # float8
        869: "GenericIPAddressField",  # inet
        1042: "CharField",  # bpchar (blank-padded)
        1043: "CharField",  # varchar
        1082: "DateField",  # date
        1083: "TimeField",  # time
        1114: "DateTimeField",  # timestamp
        1184: "DateTimeField",  # timestamptz
        1186: "DurationField",  # interval
        1266: "TimeField",  # timetz
        1700: "DecimalField",  # numeric
        2950: "UUIDField",  # uuid
        3802: "JSONField",  # jsonb
    }

    # Override point for subclasses: the access method get_constraints()
    # treats as the default one (ordering is only reported for it, and only
    # such indexes without storage options are reported as Index.suffix).
    index_default_access_method = "btree"

    # Relation names that get_table_list() leaves out of its result.
    ignored_tables: list[str] = []

    def get_field_type(self, data_type: Any, description: Any) -> str:
        """Return the Plain field type name for an introspected column.

        The base class resolves the type first. A column that resolves to
        "BigIntegerField" is reported as "PrimaryKeyField" when it is an
        identity column or when its default expression mentions ``nextval``.
        """
        field_type = super().get_field_type(data_type, description)
        # A default that calls nextval() is treated like an identity column;
        # the original note says this is required for serial columns created
        # before Plain 4.1.
        sequence_backed = description.is_autofield or (
            description.default and "nextval" in description.default
        )
        if sequence_backed and field_type == "BigIntegerField":
            return "PrimaryKeyField"
        return field_type

    def get_table_list(self, cursor: CursorWrapper) -> Sequence[TableInfo]:
        """Return the visible tables and views of the current database.

        Each entry is a ``TableInfo(name, type, comment)``. Relations whose
        name appears in ``ignored_tables`` are omitted.
        """
        query = """
            SELECT
                c.relname,
                CASE
                    WHEN c.relispartition THEN 'p'
                    WHEN c.relkind IN ('m', 'v') THEN 'v'
                    ELSE 't'
                END,
                obj_description(c.oid, 'pg_class')
            FROM pg_catalog.pg_class c
            LEFT JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace
            WHERE c.relkind IN ('f', 'm', 'p', 'r', 'v')
                AND n.nspname NOT IN ('pg_catalog', 'pg_toast')
                AND pg_catalog.pg_table_is_visible(c.oid)
        """
        cursor.execute(query)
        tables = []
        for row in cursor.fetchall():
            # row[0] is the relation name.
            if row[0] in self.ignored_tables:
                continue
            tables.append(TableInfo(*row))
        return tables

    def get_table_description(
        self, cursor: CursorWrapper, table_name: str
    ) -> Sequence[FieldInfo]:
        """Describe the columns of ``table_name`` as ``FieldInfo`` tuples.

        Two queries are combined: a catalog query supplying nullability,
        default, identity flag and comment per column, and a one-row SELECT
        whose ``cursor.description`` supplies name, type code and sizes.
        """
        # The pg_catalog tables are queried because cursor.description does
        # not reliably return the nullable property and
        # information_schema.columns does not contain details of
        # materialized views.
        catalog_query = """
            SELECT
                a.attname AS column_name,
                NOT (a.attnotnull OR (t.typtype = 'd' AND t.typnotnull)) AS is_nullable,
                pg_get_expr(ad.adbin, ad.adrelid) AS column_default,
                a.attidentity != '' AS is_autofield,
                col_description(a.attrelid, a.attnum) AS column_comment
            FROM pg_attribute a
            LEFT JOIN pg_attrdef ad ON a.attrelid = ad.adrelid AND a.attnum = ad.adnum
            JOIN pg_type t ON a.atttypid = t.oid
            JOIN pg_class c ON a.attrelid = c.oid
            JOIN pg_namespace n ON c.relnamespace = n.oid
            WHERE c.relkind IN ('f', 'm', 'p', 'r', 'v')
                AND c.relname = %s
                AND n.nspname NOT IN ('pg_catalog', 'pg_toast')
                AND pg_catalog.pg_table_is_visible(c.oid)
        """
        cursor.execute(catalog_query, [table_name])
        # column name -> (is_nullable, column_default, is_autofield, comment)
        extras_by_column = {}
        for row in cursor.fetchall():
            extras_by_column[row[0]] = row[1:]

        quoted_table = self.connection.ops.quote_name(table_name)
        cursor.execute(f"SELECT * FROM {quoted_table} LIMIT 1")

        fields = []
        for column in cursor.description:
            # Fall back to the internal size when no display size is given.
            shown_size = column.display_size
            if shown_size is None:
                shown_size = column.internal_size
            fields.append(
                FieldInfo(
                    column.name,
                    column.type_code,
                    shown_size,
                    column.internal_size,
                    column.precision,
                    column.scale,
                    *extras_by_column[column.name],
                )
            )
        return fields

    def get_sequences(
        self, cursor: CursorWrapper, table_name: str, table_fields: tuple[Any, ...] = ()
    ) -> list[dict[str, Any]]:
        """Return the sequences that depend on columns of ``table_name``.

        Each entry is ``{"name": sequence, "table": table_name, "column":
        column}``. ``table_fields`` is accepted but not used here.
        """
        query = """
            SELECT
                s.relname AS sequence_name,
                a.attname AS colname
            FROM
                pg_class s
                JOIN pg_depend d ON d.objid = s.oid
                    AND d.classid = 'pg_class'::regclass
                    AND d.refclassid = 'pg_class'::regclass
                JOIN pg_attribute a ON d.refobjid = a.attrelid
                    AND d.refobjsubid = a.attnum
                JOIN pg_class tbl ON tbl.oid = d.refobjid
                    AND tbl.relname = %s
                    AND pg_catalog.pg_table_is_visible(tbl.oid)
            WHERE
                s.relkind = 'S';
        """
        cursor.execute(query, [table_name])
        sequences = []
        for sequence_name, column_name in cursor.fetchall():
            sequences.append(
                {"name": sequence_name, "table": table_name, "column": column_name}
            )
        return sequences

    def get_relations(
        self, cursor: CursorWrapper, table_name: str
    ) -> dict[str, tuple[str, str]]:
        """Return the foreign keys of ``table_name``.

        The result maps each referencing column to
        ``(referenced_column, referenced_table)``. Only the first column of
        each constraint is considered, and only references to tables in the
        same namespace are reported.
        """
        query = """
            SELECT a1.attname, c2.relname, a2.attname
            FROM pg_constraint con
            LEFT JOIN pg_class c1 ON con.conrelid = c1.oid
            LEFT JOIN pg_class c2 ON con.confrelid = c2.oid
            LEFT JOIN
                pg_attribute a1 ON c1.oid = a1.attrelid AND a1.attnum = con.conkey[1]
            LEFT JOIN
                pg_attribute a2 ON c2.oid = a2.attrelid AND a2.attnum = con.confkey[1]
            WHERE
                c1.relname = %s AND
                con.contype = 'f' AND
                c1.relnamespace = c2.relnamespace AND
                pg_catalog.pg_table_is_visible(c1.oid)
        """
        cursor.execute(query, [table_name])
        relations = {}
        for local_column, other_table, other_column in cursor.fetchall():
            relations[local_column] = (other_column, other_table)
        return relations

    def get_constraints(
        self, cursor: CursorWrapper, table_name: str
    ) -> dict[str, dict[str, Any]]:
        """Return the constraints and indexes of ``table_name`` keyed by name.

        Covers unique, primary key, foreign key and check constraints as
        well as indexes, including the definition of expression-based
        indexes. Entries coming from pg_constraint carry the keys
        ``columns``, ``primary_key``, ``unique``, ``foreign_key``, ``check``,
        ``index``, ``definition`` and ``options``; index entries additionally
        carry ``orders`` and ``type``. An index whose name already appeared
        as a constraint is not reported a second time.
        """
        # First pass: the constraint catalog. The column array must return
        # column names in the same order in which they were created.
        constraint_query = """
            SELECT
                c.conname,
                array(
                    SELECT attname
                    FROM unnest(c.conkey) WITH ORDINALITY cols(colid, arridx)
                    JOIN pg_attribute AS ca ON cols.colid = ca.attnum
                    WHERE ca.attrelid = c.conrelid
                    ORDER BY cols.arridx
                ),
                c.contype,
                (SELECT fkc.relname || '.' || fka.attname
                FROM pg_attribute AS fka
                JOIN pg_class AS fkc ON fka.attrelid = fkc.oid
                WHERE fka.attrelid = c.confrelid AND fka.attnum = c.confkey[1]),
                cl.reloptions
            FROM pg_constraint AS c
            JOIN pg_class AS cl ON c.conrelid = cl.oid
            WHERE cl.relname = %s AND pg_catalog.pg_table_is_visible(cl.oid)
        """
        cursor.execute(constraint_query, [table_name])
        constraints: dict[str, dict[str, Any]] = {
            name: {
                "columns": columns,
                "primary_key": kind == "p",
                "unique": kind in ("p", "u"),
                # The query yields "table.column" for the referenced side.
                "foreign_key": tuple(target.split(".", 1)) if kind == "f" else None,
                "check": kind == "c",
                "index": False,
                "definition": None,
                "options": options,
            }
            for name, columns, kind, target, options in cursor.fetchall()
        }

        # Second pass: indexes.
        index_query = """
            SELECT
                indexname,
                array_agg(attname ORDER BY arridx),
                indisunique,
                indisprimary,
                array_agg(ordering ORDER BY arridx),
                amname,
                exprdef,
                s2.attoptions
            FROM (
                SELECT
                    c2.relname as indexname, idx.*, attr.attname, am.amname,
                    CASE
                        WHEN idx.indexprs IS NOT NULL THEN
                            pg_get_indexdef(idx.indexrelid)
                    END AS exprdef,
                    CASE am.amname
                        WHEN %s THEN
                            CASE (option & 1)
                                WHEN 1 THEN 'DESC' ELSE 'ASC'
                            END
                    END as ordering,
                    c2.reloptions as attoptions
                FROM (
                    SELECT *
                    FROM
                        pg_index i,
                        unnest(i.indkey, i.indoption)
                            WITH ORDINALITY koi(key, option, arridx)
                ) idx
                LEFT JOIN pg_class c ON idx.indrelid = c.oid
                LEFT JOIN pg_class c2 ON idx.indexrelid = c2.oid
                LEFT JOIN pg_am am ON c2.relam = am.oid
                LEFT JOIN
                    pg_attribute attr ON attr.attrelid = c.oid AND attr.attnum = idx.key
                WHERE c.relname = %s AND pg_catalog.pg_table_is_visible(c.oid)
            ) s2
            GROUP BY indexname, indisunique, indisprimary, amname, exprdef, attoptions;
        """
        default_method = self.index_default_access_method
        cursor.execute(index_query, [default_method, table_name])
        for row in cursor.fetchall():
            (
                index_name,
                columns,
                is_unique,
                is_primary,
                orders,
                access_method,
                definition,
                options,
            ) = row
            # Names already recorded from pg_constraint win.
            if index_name in constraints:
                continue
            # Default access method with no storage options.
            is_basic = access_method == default_method and options is None
            constraints[index_name] = {
                # A lone NULL entry is normalised to an empty list.
                "columns": [] if columns == [None] else columns,
                "orders": [] if orders == [None] else orders,
                "primary_key": is_primary,
                "unique": is_unique,
                "foreign_key": None,
                "check": False,
                "index": True,
                "type": Index.suffix if is_basic else access_method,
                "definition": definition,
                "options": options,
            }
        return constraints