1from __future__ import annotations
2
3from collections.abc import Sequence
4from typing import Any, NamedTuple
5
6from plain.models.backends.base.introspection import BaseDatabaseIntrospection
7from plain.models.backends.utils import CursorWrapper
8from plain.models.indexes import Index
9
10
class FieldInfo(NamedTuple):
    """Describe one column as reported by PostgreSQL introspection.

    The first nine entries mirror the base FieldInfo layout; the last two
    (``is_autofield`` and ``comment``) are PostgreSQL-specific additions.
    Field order matters: instances are built positionally.
    """

    # --- Entries shared with the base FieldInfo ---
    # Column name.
    name: str
    # Type code of the column; used as a key into data_types_reverse.
    type_code: Any
    display_size: int | None
    internal_size: int | None
    precision: int | None
    scale: int | None
    # Whether the column accepts NULL.
    null_ok: bool | None
    # Default expression of the column, if any.
    default: Any
    # Collation name, or None when the column uses the default collation.
    collation: str | None

    # --- PostgreSQL-only entries ---
    # True when the column is an identity column.
    is_autofield: bool
    # Comment attached to the column, if any.
    comment: str | None
27
28
class TableInfo(NamedTuple):
    """Describe one relation as reported by PostgreSQL introspection.

    Mirrors the base TableInfo layout (``name``, ``type``) and appends a
    PostgreSQL-specific ``comment`` entry.
    """

    # --- Entries shared with the base TableInfo ---
    # Relation name.
    name: str
    # Single-letter kind code for the relation.
    type: str

    # --- PostgreSQL-only entry ---
    # Comment attached to the relation, if any.
    comment: str | None
37
38
class DatabaseIntrospection(BaseDatabaseIntrospection):
    """Introspect a PostgreSQL database through its system catalogs.

    Every public method takes an open cursor, runs one or more catalog
    queries on it, and returns plain Python structures describing tables,
    columns, sequences, foreign keys, constraints, and indexes.
    """

    # Maps type codes to Plain Field types.
    data_types_reverse = {
        16: "BooleanField",
        17: "BinaryField",
        20: "BigIntegerField",
        21: "SmallIntegerField",
        23: "IntegerField",
        25: "TextField",
        700: "FloatField",
        701: "FloatField",
        869: "GenericIPAddressField",
        1042: "CharField",  # blank-padded
        1043: "CharField",
        1082: "DateField",
        1083: "TimeField",
        1114: "DateTimeField",
        1184: "DateTimeField",
        1186: "DurationField",
        1266: "TimeField",
        1700: "DecimalField",
        2950: "UUIDField",
        3802: "JSONField",
    }
    # Index access method regarded as the default one. A hook for subclasses.
    index_default_access_method = "btree"

    # Relation names that get_table_list() leaves out of its result.
    ignored_tables: list[str] = []

    def get_field_type(self, data_type: Any, description: Any) -> str:
        """
        Return the Plain field type name for a column.

        The base implementation picks the type; a "BigIntegerField" is then
        reported as "PrimaryKeyField" when the column is flagged as an
        autofield or its default expression mentions ``nextval``.
        """
        resolved_type = super().get_field_type(data_type, description)
        # Sequence-backed (serial-style) columns are recognised by the
        # nextval call in their default expression.
        auto_generated = description.is_autofield or (
            description.default and "nextval" in description.default
        )
        if auto_generated and resolved_type == "BigIntegerField":
            return "PrimaryKeyField"
        return resolved_type

    def get_table_list(self, cursor: CursorWrapper) -> Sequence[TableInfo]:
        """
        Return the visible tables and views of the current database.

        Each entry is a TableInfo whose type is 'p' for a partition, 'v' for
        relkind 'm' or 'v', and 't' for anything else. Relations whose name
        is listed in ``ignored_tables`` are skipped.
        """
        cursor.execute(
            """
            SELECT
                c.relname,
                CASE
                    WHEN c.relispartition THEN 'p'
                    WHEN c.relkind IN ('m', 'v') THEN 'v'
                    ELSE 't'
                END,
                obj_description(c.oid, 'pg_class')
            FROM pg_catalog.pg_class c
            LEFT JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace
            WHERE c.relkind IN ('f', 'm', 'p', 'r', 'v')
                AND n.nspname NOT IN ('pg_catalog', 'pg_toast')
                AND pg_catalog.pg_table_is_visible(c.oid)
        """
        )
        tables = []
        for row in cursor.fetchall():
            # row[0] is the relation name.
            if row[0] in self.ignored_tables:
                continue
            tables.append(TableInfo(*row))
        return tables

    def get_table_description(
        self, cursor: CursorWrapper, table_name: str
    ) -> Sequence[FieldInfo]:
        """
        Return one FieldInfo per column of ``table_name``.

        The result follows the DB-API cursor.description interface, extended
        with nullability, default, collation, autofield flag, and comment
        read from the catalogs.
        """
        # Query the pg_catalog tables as cursor.description does not reliably
        # return the nullable property and information_schema.columns does not
        # contain details of materialized views.
        cursor.execute(
            """
            SELECT
                a.attname AS column_name,
                NOT (a.attnotnull OR (t.typtype = 'd' AND t.typnotnull)) AS is_nullable,
                pg_get_expr(ad.adbin, ad.adrelid) AS column_default,
                CASE WHEN collname = 'default' THEN NULL ELSE collname END AS collation,
                a.attidentity != '' AS is_autofield,
                col_description(a.attrelid, a.attnum) AS column_comment
            FROM pg_attribute a
            LEFT JOIN pg_attrdef ad ON a.attrelid = ad.adrelid AND a.attnum = ad.adnum
            LEFT JOIN pg_collation co ON a.attcollation = co.oid
            JOIN pg_type t ON a.atttypid = t.oid
            JOIN pg_class c ON a.attrelid = c.oid
            JOIN pg_namespace n ON c.relnamespace = n.oid
            WHERE c.relkind IN ('f', 'm', 'p', 'r', 'v')
                AND c.relname = %s
                AND n.nspname NOT IN ('pg_catalog', 'pg_toast')
                AND pg_catalog.pg_table_is_visible(c.oid)
        """,
            [table_name],
        )
        # Column name -> remaining catalog values, in FieldInfo order
        # (null_ok, default, collation, is_autofield, comment).
        catalog_details = {}
        for row in cursor.fetchall():
            catalog_details[row[0]] = row[1:]

        # A one-row SELECT is issued only to populate cursor.description.
        quoted_table = self.connection.ops.quote_name(table_name)
        cursor.execute(f"SELECT * FROM {quoted_table} LIMIT 1")

        fields = []
        for column in cursor.description:
            fields.append(
                FieldInfo(
                    column.name,
                    column.type_code,
                    # Fall back to internal_size when no display size is given.
                    (
                        column.display_size
                        if column.display_size is not None
                        else column.internal_size
                    ),
                    column.internal_size,
                    column.precision,
                    column.scale,
                    *catalog_details[column.name],
                )
            )
        return fields

    def get_sequences(
        self, cursor: CursorWrapper, table_name: str, table_fields: tuple[Any, ...] = ()
    ) -> list[dict[str, Any]]:
        """
        Return the sequences that depend on columns of ``table_name``.

        Each entry is a dict with the keys "name" (sequence name), "table"
        (the given table name), and "column" (the column the sequence is
        tied to). ``table_fields`` is accepted but not used here.
        """
        cursor.execute(
            """
            SELECT
                s.relname AS sequence_name,
                a.attname AS colname
            FROM
                pg_class s
                JOIN pg_depend d ON d.objid = s.oid
                    AND d.classid = 'pg_class'::regclass
                    AND d.refclassid = 'pg_class'::regclass
                JOIN pg_attribute a ON d.refobjid = a.attrelid
                    AND d.refobjsubid = a.attnum
                JOIN pg_class tbl ON tbl.oid = d.refobjid
                    AND tbl.relname = %s
                    AND pg_catalog.pg_table_is_visible(tbl.oid)
            WHERE
                s.relkind = 'S';
        """,
            [table_name],
        )
        sequences = []
        for sequence_name, column_name in cursor.fetchall():
            sequences.append(
                {"name": sequence_name, "table": table_name, "column": column_name}
            )
        return sequences

    def get_relations(
        self, cursor: CursorWrapper, table_name: str
    ) -> dict[str, tuple[str, str]]:
        """
        Return a dictionary of {field_name: (field_name_other_table, other_table)}
        representing all foreign keys in the given table.

        Only the first column of each foreign key is examined, and only
        references within the same namespace are reported.
        """
        cursor.execute(
            """
            SELECT a1.attname, c2.relname, a2.attname
            FROM pg_constraint con
            LEFT JOIN pg_class c1 ON con.conrelid = c1.oid
            LEFT JOIN pg_class c2 ON con.confrelid = c2.oid
            LEFT JOIN
                pg_attribute a1 ON c1.oid = a1.attrelid AND a1.attnum = con.conkey[1]
            LEFT JOIN
                pg_attribute a2 ON c2.oid = a2.attrelid AND a2.attnum = con.confkey[1]
            WHERE
                c1.relname = %s AND
                con.contype = 'f' AND
                c1.relnamespace = c2.relnamespace AND
                pg_catalog.pg_table_is_visible(c1.oid)
        """,
            [table_name],
        )
        relations = {}
        for local_column, other_table, other_column in cursor.fetchall():
            relations[local_column] = (other_column, other_table)
        return relations

    def get_constraints(
        self, cursor: CursorWrapper, table_name: str
    ) -> dict[str, dict[str, Any]]:
        """
        Retrieve any constraints or keys (unique, pk, fk, check, index) across
        one or more columns. Also retrieve the definition of expression-based
        indexes.

        The result maps each constraint or index name to a dict of its
        properties. Indexes whose name already appeared as a constraint are
        not added a second time.
        """
        constraints: dict[str, dict[str, Any]] = {}

        # Loop over the key table, collecting things as constraints. The column
        # array must return column names in the same order in which they were
        # created.
        cursor.execute(
            """
            SELECT
                c.conname,
                array(
                    SELECT attname
                    FROM unnest(c.conkey) WITH ORDINALITY cols(colid, arridx)
                    JOIN pg_attribute AS ca ON cols.colid = ca.attnum
                    WHERE ca.attrelid = c.conrelid
                    ORDER BY cols.arridx
                ),
                c.contype,
                (SELECT fkc.relname || '.' || fka.attname
                FROM pg_attribute AS fka
                JOIN pg_class AS fkc ON fka.attrelid = fkc.oid
                WHERE fka.attrelid = c.confrelid AND fka.attnum = c.confkey[1]),
                cl.reloptions
            FROM pg_constraint AS c
            JOIN pg_class AS cl ON c.conrelid = cl.oid
            WHERE cl.relname = %s AND pg_catalog.pg_table_is_visible(cl.oid)
        """,
            [table_name],
        )
        for conname, columns, contype, referenced, reloptions in cursor.fetchall():
            is_primary = contype == "p"
            # "table.column" of the referenced side, split into a 2-tuple.
            foreign_key = None
            if contype == "f":
                foreign_key = tuple(referenced.split(".", 1))
            constraints[conname] = {
                "columns": columns,
                "primary_key": is_primary,
                "unique": is_primary or contype == "u",
                "foreign_key": foreign_key,
                "check": contype == "c",
                "index": False,
                "definition": None,
                "options": reloptions,
            }

        # Now get indexes
        default_method = self.index_default_access_method
        cursor.execute(
            """
            SELECT
                indexname,
                array_agg(attname ORDER BY arridx),
                indisunique,
                indisprimary,
                array_agg(ordering ORDER BY arridx),
                amname,
                exprdef,
                s2.attoptions
            FROM (
                SELECT
                    c2.relname as indexname, idx.*, attr.attname, am.amname,
                    CASE
                        WHEN idx.indexprs IS NOT NULL THEN
                            pg_get_indexdef(idx.indexrelid)
                    END AS exprdef,
                    CASE am.amname
                        WHEN %s THEN
                            CASE (option & 1)
                                WHEN 1 THEN 'DESC' ELSE 'ASC'
                            END
                    END as ordering,
                    c2.reloptions as attoptions
                FROM (
                    SELECT *
                    FROM
                        pg_index i,
                        unnest(i.indkey, i.indoption)
                            WITH ORDINALITY koi(key, option, arridx)
                ) idx
                LEFT JOIN pg_class c ON idx.indrelid = c.oid
                LEFT JOIN pg_class c2 ON idx.indexrelid = c2.oid
                LEFT JOIN pg_am am ON c2.relam = am.oid
                LEFT JOIN
                    pg_attribute attr ON attr.attrelid = c.oid AND attr.attnum = idx.key
                WHERE c.relname = %s AND pg_catalog.pg_table_is_visible(c.oid)
            ) s2
            GROUP BY indexname, indisunique, indisprimary, amname, exprdef, attoptions;
        """,
            [default_method, table_name],
        )
        for (
            index_name,
            columns,
            is_unique,
            is_primary,
            orders,
            access_method,
            definition,
            options,
        ) in cursor.fetchall():
            # Entries already recorded as constraints keep their constraint data.
            if index_name in constraints:
                continue
            # A default-method index without storage options is reported with
            # the generic Index suffix instead of its access method name.
            is_basic = access_method == default_method and options is None
            constraints[index_name] = {
                # A lone None means no column name / ordering was aggregated.
                "columns": [] if columns == [None] else columns,
                "orders": [] if orders == [None] else orders,
                "primary_key": is_primary,
                "unique": is_unique,
                "foreign_key": None,
                "check": False,
                "index": True,
                "type": Index.suffix if is_basic else access_method,
                "definition": definition,
                "options": options,
            }
        return constraints