1from __future__ import annotations
2
3from typing import Any, NamedTuple
4
5from plain.models.backends.base.introspection import BaseDatabaseIntrospection
6from plain.models.backends.utils import CursorWrapper
7from plain.models.indexes import Index
8
9
class FieldInfo(NamedTuple):
    """PostgreSQL-specific FieldInfo extending base with autofield and comment.

    One instance describes one column. The field order is positional and
    matters: ``get_table_description`` in this module builds instances by
    position, passing six values taken from ``cursor.description`` followed by
    five values read from the PostgreSQL catalogs.
    """

    # Fields from BaseFieldInfo
    # Column name (``cursor.description`` entry ``name``).
    name: str
    # Driver type code of the column (``cursor.description`` entry ``type_code``).
    type_code: Any
    # ``display_size`` from the cursor; this module substitutes ``internal_size``
    # when the driver reports ``None``.
    display_size: int | None
    # ``internal_size`` as reported by the cursor.
    internal_size: int | None
    # ``precision`` as reported by the cursor.
    precision: int | None
    # ``scale`` as reported by the cursor.
    scale: int | None
    # Whether the column accepts NULL (the ``is_nullable`` catalog expression).
    null_ok: bool | None
    # Column default expression as returned by ``pg_get_expr``; None without one.
    default: Any
    # Collation name; the catalog query maps the 'default' collation to None.
    collation: str | None
    # PostgreSQL-specific extensions
    # True when ``pg_attribute.attidentity`` is non-empty for the column.
    is_autofield: bool
    # Column comment as returned by ``col_description``.
    comment: str | None
26
27
class TableInfo(NamedTuple):
    """PostgreSQL-specific TableInfo extending base with comment support.

    One instance describes one relation. ``get_table_list`` in this module
    builds instances positionally from rows of ``(relname, type, comment)``.
    """

    # Fields from BaseTableInfo
    # Relation name (``pg_class.relname``).
    name: str
    # One-letter kind computed by the ``get_table_list`` query: 'p' when
    # ``relispartition`` is set, 'v' for relkind 'm' or 'v', otherwise 't'.
    type: str
    # PostgreSQL-specific extension
    # Relation comment as returned by ``obj_description(oid, 'pg_class')``.
    comment: str | None
36
37
class DatabaseIntrospection(BaseDatabaseIntrospection):
    """Introspect a PostgreSQL database by querying its system catalogs.

    Every lookup is restricted to relations for which
    ``pg_catalog.pg_table_is_visible`` is true.
    """

    # Maps type codes to Plain Field types.
    data_types_reverse = {
        16: "BooleanField",
        17: "BinaryField",
        20: "BigIntegerField",
        21: "SmallIntegerField",
        23: "IntegerField",
        25: "TextField",
        700: "FloatField",
        701: "FloatField",
        869: "GenericIPAddressField",
        1042: "CharField",  # blank-padded
        1043: "CharField",
        1082: "DateField",
        1083: "TimeField",
        1114: "DateTimeField",
        1184: "DateTimeField",
        1186: "DurationField",
        1266: "TimeField",
        1700: "DecimalField",
        2950: "UUIDField",
        3802: "JSONField",
    }
    # Overridable by subclasses: indexes using this access method (and no
    # storage options) are reported with the generic ``Index.suffix`` type.
    index_default_access_method = "btree"

    # Relation names that get_table_list() leaves out of its result.
    ignored_tables: list[str] = []

    def get_field_type(self, data_type: Any, description: Any) -> str:
        """Return the Plain field type name for a column.

        The base class picks the type first. If that type is
        ``BigIntegerField`` and the column is either flagged ``is_autofield``
        or has a default containing ``nextval``, ``PrimaryKeyField`` is
        returned instead.
        """
        resolved = super().get_field_type(data_type, description)
        # Columns whose default draws from a sequence are treated the same as
        # columns flagged as auto fields.
        sequence_backed = description.is_autofield or (
            description.default and "nextval" in description.default
        )
        if sequence_backed and resolved == "BigIntegerField":
            return "PrimaryKeyField"
        return resolved

    def get_table_list(self, cursor: CursorWrapper) -> list[TableInfo]:
        """Return a ``TableInfo`` for each visible relation in the database.

        The query keeps relkinds 'f', 'm', 'p', 'r' and 'v', skips the
        ``pg_catalog`` and ``pg_toast`` namespaces, and labels each row 'p'
        (partition), 'v' (relkind 'm' or 'v') or 't' (anything else). Names
        listed in ``ignored_tables`` are dropped from the result.
        """
        sql = """
            SELECT
                c.relname,
                CASE
                    WHEN c.relispartition THEN 'p'
                    WHEN c.relkind IN ('m', 'v') THEN 'v'
                    ELSE 't'
                END,
                obj_description(c.oid, 'pg_class')
            FROM pg_catalog.pg_class c
            LEFT JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace
            WHERE c.relkind IN ('f', 'm', 'p', 'r', 'v')
                AND n.nspname NOT IN ('pg_catalog', 'pg_toast')
                AND pg_catalog.pg_table_is_visible(c.oid)
        """
        cursor.execute(sql)
        tables = []
        for record in cursor.fetchall():
            if record[0] in self.ignored_tables:
                continue
            tables.append(TableInfo(*record))
        return tables

    def get_table_description(
        self, cursor: CursorWrapper, table_name: str
    ) -> list[FieldInfo]:
        """
        Return one ``FieldInfo`` per column of ``table_name``, in the order
        reported by ``cursor.description``.

        Two queries are issued: a catalog query for nullability, default,
        collation, identity flag and comment, then a ``SELECT * ... LIMIT 1``
        whose ``cursor.description`` supplies the name, type code and sizes.
        """
        # Query the pg_catalog tables as cursor.description does not reliably
        # return the nullable property and information_schema.columns does not
        # contain details of materialized views.
        catalog_sql = """
            SELECT
                a.attname AS column_name,
                NOT (a.attnotnull OR (t.typtype = 'd' AND t.typnotnull)) AS is_nullable,
                pg_get_expr(ad.adbin, ad.adrelid) AS column_default,
                CASE WHEN collname = 'default' THEN NULL ELSE collname END AS collation,
                a.attidentity != '' AS is_autofield,
                col_description(a.attrelid, a.attnum) AS column_comment
            FROM pg_attribute a
            LEFT JOIN pg_attrdef ad ON a.attrelid = ad.adrelid AND a.attnum = ad.adnum
            LEFT JOIN pg_collation co ON a.attcollation = co.oid
            JOIN pg_type t ON a.atttypid = t.oid
            JOIN pg_class c ON a.attrelid = c.oid
            JOIN pg_namespace n ON c.relnamespace = n.oid
            WHERE c.relkind IN ('f', 'm', 'p', 'r', 'v')
                AND c.relname = %s
                AND n.nspname NOT IN ('pg_catalog', 'pg_toast')
                AND pg_catalog.pg_table_is_visible(c.oid)
        """
        cursor.execute(catalog_sql, [table_name])
        # Column name -> (is_nullable, default, collation, is_autofield, comment).
        catalog_details = {}
        for record in cursor.fetchall():
            catalog_details[record[0]] = record[1:]

        # Run a trivial select so the driver populates cursor.description.
        quoted_table = self.connection.ops.quote_name(table_name)
        cursor.execute(f"SELECT * FROM {quoted_table} LIMIT 1")

        fields = []
        for column in cursor.description:
            # Fall back to the internal size when no display size is reported.
            if column.display_size is None:
                display_size = column.internal_size
            else:
                display_size = column.display_size
            fields.append(
                FieldInfo(
                    column.name,
                    column.type_code,
                    display_size,
                    column.internal_size,
                    column.precision,
                    column.scale,
                    *catalog_details[column.name],
                )
            )
        return fields

    def get_sequences(
        self, cursor: CursorWrapper, table_name: str, table_fields: tuple[Any, ...] = ()
    ) -> list[dict[str, Any]]:
        """
        Return the sequences that ``pg_depend`` links to columns of
        ``table_name``, as dicts with "name", "table" and "column" keys.

        ``table_fields`` is accepted but not used by this implementation.
        """
        sql = """
            SELECT
                s.relname AS sequence_name,
                a.attname AS colname
            FROM
                pg_class s
                JOIN pg_depend d ON d.objid = s.oid
                    AND d.classid = 'pg_class'::regclass
                    AND d.refclassid = 'pg_class'::regclass
                JOIN pg_attribute a ON d.refobjid = a.attrelid
                    AND d.refobjsubid = a.attnum
                JOIN pg_class tbl ON tbl.oid = d.refobjid
                    AND tbl.relname = %s
                    AND pg_catalog.pg_table_is_visible(tbl.oid)
            WHERE
                s.relkind = 'S';
        """
        cursor.execute(sql, [table_name])
        sequences = []
        for record in cursor.fetchall():
            sequences.append(
                {"name": record[0], "table": table_name, "column": record[1]}
            )
        return sequences

    def get_relations(
        self, cursor: CursorWrapper, table_name: str
    ) -> dict[str, tuple[str, str]]:
        """
        Map each foreign-key column of ``table_name`` to a
        ``(referenced_column, referenced_table)`` pair.

        Only the first column of each constraint is read (``conkey[1]`` /
        ``confkey[1]``), and only constraints whose two tables share a
        namespace are returned.
        """
        sql = """
            SELECT a1.attname, c2.relname, a2.attname
            FROM pg_constraint con
            LEFT JOIN pg_class c1 ON con.conrelid = c1.oid
            LEFT JOIN pg_class c2 ON con.confrelid = c2.oid
            LEFT JOIN
                pg_attribute a1 ON c1.oid = a1.attrelid AND a1.attnum = con.conkey[1]
            LEFT JOIN
                pg_attribute a2 ON c2.oid = a2.attrelid AND a2.attnum = con.confkey[1]
            WHERE
                c1.relname = %s AND
                con.contype = 'f' AND
                c1.relnamespace = c2.relnamespace AND
                pg_catalog.pg_table_is_visible(c1.oid)
        """
        cursor.execute(sql, [table_name])
        relations = {}
        for record in cursor.fetchall():
            # Row layout: (local column, referenced table, referenced column).
            relations[record[0]] = (record[2], record[1])
        return relations

    def get_constraints(
        self, cursor: CursorWrapper, table_name: str
    ) -> dict[str, dict[str, Any]]:
        """
        Return every constraint and index on ``table_name``, keyed by name.

        Entries read from ``pg_constraint`` carry the keys "columns",
        "primary_key", "unique", "foreign_key", "check", "index",
        "definition" and "options". Entries read from ``pg_index`` are added
        only for names not already present and additionally carry "orders"
        and "type"; for expression indexes "definition" holds the output of
        ``pg_get_indexdef``.
        """
        found: dict[str, dict[str, Any]] = {}

        # First pass: declared constraints. The column array is ordered by
        # position within conkey so that names come back in constraint order.
        constraint_sql = """
            SELECT
                c.conname,
                array(
                    SELECT attname
                    FROM unnest(c.conkey) WITH ORDINALITY cols(colid, arridx)
                    JOIN pg_attribute AS ca ON cols.colid = ca.attnum
                    WHERE ca.attrelid = c.conrelid
                    ORDER BY cols.arridx
                ),
                c.contype,
                (SELECT fkc.relname || '.' || fka.attname
                FROM pg_attribute AS fka
                JOIN pg_class AS fkc ON fka.attrelid = fkc.oid
                WHERE fka.attrelid = c.confrelid AND fka.attnum = c.confkey[1]),
                cl.reloptions
            FROM pg_constraint AS c
            JOIN pg_class AS cl ON c.conrelid = cl.oid
            WHERE cl.relname = %s AND pg_catalog.pg_table_is_visible(cl.oid)
        """
        cursor.execute(constraint_sql, [table_name])
        for name, column_names, contype, fk_target, rel_options in cursor.fetchall():
            is_primary = contype == "p"
            if contype == "f":
                # "table.column" as concatenated by the query; split once.
                foreign_key = tuple(fk_target.split(".", 1))
            else:
                foreign_key = None
            found[name] = {
                "columns": column_names,
                "primary_key": is_primary,
                "unique": is_primary or contype == "u",
                "foreign_key": foreign_key,
                "check": contype == "c",
                "index": False,
                "definition": None,
                "options": rel_options,
            }

        # Second pass: indexes, one row per index with its columns aggregated
        # in key order.
        index_sql = """
            SELECT
                indexname,
                array_agg(attname ORDER BY arridx),
                indisunique,
                indisprimary,
                array_agg(ordering ORDER BY arridx),
                amname,
                exprdef,
                s2.attoptions
            FROM (
                SELECT
                    c2.relname as indexname, idx.*, attr.attname, am.amname,
                    CASE
                        WHEN idx.indexprs IS NOT NULL THEN
                            pg_get_indexdef(idx.indexrelid)
                    END AS exprdef,
                    CASE am.amname
                        WHEN %s THEN
                            CASE (option & 1)
                                WHEN 1 THEN 'DESC' ELSE 'ASC'
                            END
                    END as ordering,
                    c2.reloptions as attoptions
                FROM (
                    SELECT *
                    FROM
                        pg_index i,
                        unnest(i.indkey, i.indoption)
                            WITH ORDINALITY koi(key, option, arridx)
                ) idx
                LEFT JOIN pg_class c ON idx.indrelid = c.oid
                LEFT JOIN pg_class c2 ON idx.indexrelid = c2.oid
                LEFT JOIN pg_am am ON c2.relam = am.oid
                LEFT JOIN
                    pg_attribute attr ON attr.attrelid = c.oid AND attr.attnum = idx.key
                WHERE c.relname = %s AND pg_catalog.pg_table_is_visible(c.oid)
            ) s2
            GROUP BY indexname, indisunique, indisprimary, amname, exprdef, attoptions;
        """
        cursor.execute(index_sql, [self.index_default_access_method, table_name])
        for (
            name,
            column_names,
            is_unique,
            is_primary,
            orderings,
            access_method,
            expression_def,
            rel_options,
        ) in cursor.fetchall():
            # Names already recorded from pg_constraint keep that entry.
            if name in found:
                continue
            uses_default_method = (
                access_method == self.index_default_access_method
                and rel_options is None
            )
            found[name] = {
                # A lone NULL from array_agg means no plain column/ordering.
                "columns": [] if column_names == [None] else column_names,
                "orders": [] if orderings == [None] else orderings,
                "primary_key": is_primary,
                "unique": is_unique,
                "foreign_key": None,
                "check": False,
                "index": True,
                "type": Index.suffix if uses_default_method else access_method,
                "definition": expression_def,
                "options": rel_options,
            }
        return found