1from __future__ import annotations
2
3import re
4from dataclasses import dataclass, field
5from enum import StrEnum
6from typing import TYPE_CHECKING
7
8import sqlparse
9
10from ..db import get_connection
11from ..indexes import Index
12
13if TYPE_CHECKING:
14 from ..connection import DatabaseConnection
15 from ..utils import CursorWrapper
16
# Name of the default index access method (pg_am.amname); also the default
# value of IndexState.access_method.
DEFAULT_INDEX_ACCESS_METHOD = "btree"

# The set of index access methods convergence knows how to create and manage.
# Add entries here (e.g. "gin", "gist") once support for those types ships.
MANAGED_INDEX_ACCESS_METHODS: frozenset[str] = frozenset((DEFAULT_INDEX_ACCESS_METHOD,))
22
23
24class ConType(StrEnum):
25 """Postgres pg_constraint.contype values."""
26
27 PRIMARY = "p"
28 UNIQUE = "u"
29 CHECK = "c"
30 FOREIGN_KEY = "f"
31 EXCLUSION = "x"
32
33 @property
34 def label(self) -> str:
35 return _CONTYPE_LABELS[self]
36
37
38_CONTYPE_LABELS: dict[ConType, str] = {
39 ConType.PRIMARY: "primary key",
40 ConType.UNIQUE: "unique",
41 ConType.CHECK: "check",
42 ConType.FOREIGN_KEY: "foreign key",
43 ConType.EXCLUSION: "exclusion",
44}
45
46# Constraint types that convergence can create and manage.
47# Expand when support for new constraint types ships.
48MANAGED_CONSTRAINT_TYPES: frozenset[ConType] = frozenset(
49 {ConType.UNIQUE, ConType.CHECK, ConType.FOREIGN_KEY}
50)
51
52
@dataclass
class ColumnState:
    """One table column as read from pg_attribute by _get_columns()."""

    # Type name rendered by format_type(atttypid, atttypmod).
    type: str
    # Mirrors pg_attribute.attnotnull.
    not_null: bool
59
60
@dataclass
class IndexState:
    """An index as described by pg_index joined with pg_am."""

    # Indexed column names as reported by get_constraints().
    columns: list[str]
    # pg_am.amname of the index, e.g. "btree".
    access_method: str = field(default=DEFAULT_INDEX_ACCESS_METHOD)
    is_unique: bool = field(default=False)
    # NOTE(review): presumably mirrors pg_index.indisvalid — confirm against get_constraints().
    is_valid: bool = field(default=True)
    # Full index definition text, when get_constraints() provides one.
    definition: str | None = field(default=None)
70
71
@dataclass
class ConstraintState:
    """A single row of pg_constraint, for any constraint type.

    Like the pg_constraint catalog itself, one class covers every kind of
    constraint. ``target_table`` and ``target_column`` are filled in only for
    foreign keys and stay ``None`` otherwise.
    """

    constraint_type: ConType
    columns: list[str]
    validated: bool = field(default=True)
    definition: str | None = field(default=None)
    # Referenced table and column; foreign keys only.
    target_table: str | None = field(default=None)
    target_column: str | None = field(default=None)
87
88
@dataclass
class TableState:
    """Snapshot of one table as it currently exists in the database.

    The fields follow the Postgres catalogs they are read from:

    * ``columns``     -- pg_attribute, keyed by column name
    * ``indexes``     -- pg_index + pg_am, keyed by index name
    * ``constraints`` -- pg_constraint, all constraint types in a single
      dict keyed by constraint name

    introspect_table() returns ``exists=False`` with empty collections when
    the table is not found.
    """

    exists: bool = field(default=True)
    columns: dict[str, ColumnState] = field(default_factory=dict)
    indexes: dict[str, IndexState] = field(default_factory=dict)
    constraints: dict[str, ConstraintState] = field(default_factory=dict)
103
104
def introspect_table(
    conn: DatabaseConnection, cursor: CursorWrapper, table_name: str
) -> TableState:
    """Query the database and return the raw state of a table.

    Args:
        conn: Connection whose ``get_constraints()`` supplies index and
            constraint metadata for the table.
        cursor: Cursor used for the catalog queries.
        table_name: Name of the table to inspect.

    Returns:
        ``TableState(exists=False)`` when no columns are found for the table,
        otherwise a populated ``TableState``. Entries are left out when their
        contype is not a ``ConType`` member and they are not flagged as an
        index. Foreign keys that are not a single column referencing a
        (table, column) pair are also left out.
    """
    actual_columns = _get_columns(cursor, table_name)
    if not actual_columns:
        return TableState(exists=False)

    raw = conn.get_constraints(cursor, table_name)

    indexes: dict[str, IndexState] = {}
    constraints: dict[str, ConstraintState] = {}

    for name, info in raw.items():
        # Map the raw contype onto ConType when it is a known constraint type.
        raw_contype = info.get("contype")
        contype: ConType | None = None
        if raw_contype:
            try:
                contype = ConType(raw_contype)
            except ValueError:
                # contype values not listed in ConType are not modeled; such
                # entries fall through to the index check below.
                contype = None

        # `or []` also covers a key that is present but None. Copying keeps
        # our state independent of the dicts returned by get_constraints().
        columns = list(info.get("columns") or [])

        if contype in (
            ConType.PRIMARY,
            ConType.UNIQUE,
            ConType.CHECK,
            ConType.EXCLUSION,
        ):
            constraints[name] = ConstraintState(
                constraint_type=contype,
                columns=columns,
                validated=info.get("validated", True),
                definition=info.get("definition"),
            )
        elif contype == ConType.FOREIGN_KEY:
            # Expected as a (target_table, target_column) pair; `or ()`
            # guards against an explicit None.
            fk_target = info.get("foreign_key") or ()
            # Only single-column foreign keys are modeled; others are skipped.
            if len(columns) == 1 and len(fk_target) == 2:
                constraints[name] = ConstraintState(
                    constraint_type=ConType.FOREIGN_KEY,
                    columns=columns,
                    validated=info.get("validated", True),
                    definition=info.get("definition"),
                    target_table=fk_target[0],
                    target_column=fk_target[1],
                )
        elif info.get("index"):
            # get_constraints() encodes basic btree indexes as Index.suffix ("idx")
            # and non-btree indexes as their raw pg_am.amname. Reverse that here.
            raw_type = info.get("type") or DEFAULT_INDEX_ACCESS_METHOD
            access_method = (
                DEFAULT_INDEX_ACCESS_METHOD if raw_type == Index.suffix else raw_type
            )
            indexes[name] = IndexState(
                columns=columns,
                access_method=access_method,
                is_unique=info.get("unique", False),
                is_valid=info.get("valid", True),
                definition=info.get("definition"),
            )

    return TableState(
        exists=True,
        columns=actual_columns,
        indexes=indexes,
        constraints=constraints,
    )
174
175
def get_unknown_tables(conn: DatabaseConnection | None = None) -> list[str]:
    """List database tables that no Plain model manages, sorted by name.

    The migrations bookkeeping table is never reported. When ``conn`` is
    omitted, the connection from ``get_connection()`` is used.
    """
    from ..migrations.recorder import MIGRATION_TABLE_NAME

    connection = conn if conn is not None else get_connection()
    existing = set(connection.table_names())
    managed = set(connection.plain_table_names())
    managed.add(MIGRATION_TABLE_NAME)
    return sorted(existing - managed)
185
186
187def _strip_balanced_parens(s: str) -> str:
188 """Strip redundant outermost parentheses when they wrap the entire expression."""
189 while s.startswith("(") and s.endswith(")"):
190 inner = s[1:-1]
191 depth = 0
192 balanced = True
193 for ch in inner:
194 if ch == "(":
195 depth += 1
196 elif ch == ")":
197 depth -= 1
198 if depth < 0:
199 balanced = False
200 break
201 if balanced and depth == 0:
202 s = inner.strip()
203 else:
204 break
205 return s
206
207
def _normalize_sql(s: str) -> str:
    """Canonicalize SQL text for comparison.

    sqlparse lowercases keywords and identifiers. Then every double quote is
    removed (including any inside string literals), and each run of
    whitespace collapses to a single space with no leading or trailing
    whitespace.
    """
    formatted = sqlparse.format(
        s,
        strip_whitespace=True,
        identifier_case="lower",
        keyword_case="lower",
    )
    unquoted = formatted.replace('"', "")
    return " ".join(unquoted.split())
215
216
217def _strip_type_casts(s: str) -> str:
218 """Strip PostgreSQL type casts (e.g. ''::text, 0::integer).
219
220 PostgreSQL adds explicit casts to stored definitions (pg_get_indexdef,
221 pg_get_constraintdef) but the ORM compiler omits them. Only used for
222 expression/condition comparison where the two generators diverge.
223 """
224 return re.sub(r"::\w+", "", s)
225
226
def normalize_check_definition(s: str) -> str:
    """Normalize a CHECK/condition definition for comparison.

    Applies the shared SQL normalization and strips PG type casts. It then
    drops a trailing NOT VALID marker, unwraps an outer ``CHECK (...)`` and
    removes redundant parentheses. The result lets pg_get_constraintdef /
    pg_get_indexdef output be compared with model-generated SQL.

    The CHECK wrapper is removed only when ``check`` is immediately followed
    by ``(``. A condition that merely starts with an identifier such as
    ``checked_at`` is left intact.
    """
    s = _normalize_sql(s)
    s = _strip_type_casts(s)
    # pg_get_constraintdef appends " NOT VALID" to unvalidated constraints.
    # Validation state is tracked separately (ConstraintState.validated), so
    # the marker must not make the expression itself compare unequal.
    s = re.sub(r"\s+not\s+valid$", "", s)
    wrapper = re.match(r"check\s*(?=\()", s)
    if wrapper:
        s = s[wrapper.end():].strip()
        # CHECK syntax puts its own parentheses around the whole expression.
        # Strip that pair directly, even if a string literal inside contains
        # unbalanced parens.
        if s.startswith("(") and s.endswith(")"):
            s = s[1:-1].strip()
    return _strip_balanced_parens(s)
243
244
def normalize_unique_definition(s: str) -> str:
    """Normalize a UNIQUE constraint definition for comparison.

    After the shared SQL normalization, a leading ``unique`` keyword is
    removed. This makes pg_get_constraintdef output such as ``UNIQUE (email)``
    line up with model-generated definitions. Everything after the keyword is
    kept, including any INCLUDE (...) or DEFERRABLE clauses PostgreSQL emits,
    so those clauses still take part in the comparison.
    """
    normalized = _normalize_sql(s)
    keyword = "unique"
    if not normalized.startswith(keyword):
        return normalized
    return normalized[len(keyword):].strip()
256
257
def normalize_expression(s: str) -> str:
    """Normalize a single index expression so DB and model forms can be compared.

    Runs the shared SQL normalization (lowercasing, quote removal, whitespace
    collapsing). It then peels off any parentheses that wrap the whole
    expression, so that e.g. 'LOWER("col")' can be compared with 'lower(col)'.
    """
    lowered = _normalize_sql(s)
    return _strip_balanced_parens(lowered)
266
267
def normalize_index_definition(s: str) -> str:
    """Extract and normalize the expression part of a CREATE INDEX definition.

    After the shared SQL normalization, the text up to and including
    ``using <method> `` is dropped. Without a USING clause, everything before
    the first ``(`` that follows `` on `` is dropped instead. Parentheses
    wrapping the whole remainder are then stripped, so the model's
    ``((UPPER(col)))`` and the DB's ``(upper(col))`` end up identical.
    Any text after the expression list (for example a WHERE clause, if
    present) is kept.

    Example: 'CREATE INDEX foo ON bar USING btree (upper(email))'
             -> 'upper(email)'
    """
    normalized = _normalize_sql(s)
    using = re.search(r"\busing \w+ ", normalized)
    if using is not None:
        return _strip_balanced_parens(normalized[using.end() :])
    on_pos = normalized.find(" on ")
    paren = normalized.find("(", on_pos) if on_pos >= 0 else -1
    if paren >= 0:
        normalized = normalized[paren:]
    return _strip_balanced_parens(normalized)
293
294
def _get_columns(cursor: CursorWrapper, table_name: str) -> dict[str, ColumnState]:
    """Read the live columns of ``table_name`` from pg_attribute.

    The table is looked up by name among relations visible on the current
    search path. System columns (attnum <= 0) and dropped columns are
    skipped. The returned dict preserves column order (attnum) and is empty
    when nothing matches.
    """
    cursor.execute(
        """
        SELECT a.attname, format_type(a.atttypid, a.atttypmod), a.attnotnull
        FROM pg_attribute a
        JOIN pg_class c ON a.attrelid = c.oid
        WHERE c.relname = %s AND pg_catalog.pg_table_is_visible(c.oid)
        AND a.attnum > 0 AND NOT a.attisdropped
        ORDER BY a.attnum
        """,
        [table_name],
    )
    columns: dict[str, ColumnState] = {}
    for attname, formatted_type, attnotnull in cursor.fetchall():
        columns[attname] = ColumnState(type=formatted_type, not_null=attnotnull)
    return columns