1from __future__ import annotations
  2
  3import re
  4from dataclasses import dataclass, field
  5from enum import StrEnum
  6from typing import TYPE_CHECKING
  7
  8import sqlparse
  9
 10from ..db import get_connection
 11from ..indexes import Index
 12
 13if TYPE_CHECKING:
 14    from ..connection import DatabaseConnection
 15    from ..utils import CursorWrapper
 16
# Access method used when none is reported for an index; also the default
# value of IndexState.access_method.
DEFAULT_INDEX_ACCESS_METHOD = "btree"

# Index access methods that convergence can create and manage.
# Currently only the default ("btree").
# Expand when support for new index types ships (e.g. "gin", "gist").
MANAGED_INDEX_ACCESS_METHODS: frozenset[str] = frozenset({DEFAULT_INDEX_ACCESS_METHOD})
 22
 23
class ConType(StrEnum):
    """Postgres pg_constraint.contype values.

    Each member's value is the single-character code stored in the catalog,
    so ``ConType("u")`` yields ``ConType.UNIQUE`` and an unknown code raises
    ``ValueError``.
    """

    PRIMARY = "p"
    UNIQUE = "u"
    CHECK = "c"
    FOREIGN_KEY = "f"
    EXCLUSION = "x"

    @property
    def label(self) -> str:
        """Return the human-readable name for this type (e.g. "primary key").

        Looked up in the module-level ``_CONTYPE_LABELS`` mapping.
        """
        return _CONTYPE_LABELS[self]
 36
 37
# Human-readable label for every ConType member; backs the ConType.label
# property, so each member needs an entry here.
_CONTYPE_LABELS: dict[ConType, str] = {
    ConType.PRIMARY: "primary key",
    ConType.UNIQUE: "unique",
    ConType.CHECK: "check",
    ConType.FOREIGN_KEY: "foreign key",
    ConType.EXCLUSION: "exclusion",
}

# Constraint types that convergence can create and manage.
# PRIMARY and EXCLUSION are intentionally absent from this set.
# Expand when support for new constraint types ships.
MANAGED_CONSTRAINT_TYPES: frozenset[ConType] = frozenset(
    {ConType.UNIQUE, ConType.CHECK, ConType.FOREIGN_KEY}
)
 51
 52
@dataclass
class ColumnState:
    """A column from pg_attribute.

    Instances are built by ``_get_columns`` from one row of its catalog query.
    """

    # Type name as rendered by format_type(atttypid, atttypmod).
    type: str
    # Value of pg_attribute.attnotnull for the column.
    not_null: bool
 59
 60
@dataclass
class IndexState:
    """An index from pg_index + pg_am.

    Built by ``introspect_table`` for entries that are flagged as an index and
    do not map to a known constraint type.
    """

    # Column names covered by the index (empty when none are reported).
    columns: list[str]
    # Access method name (pg_am.amname); defaults to "btree".
    access_method: str = DEFAULT_INDEX_ACCESS_METHOD
    # Copied from the "unique" entry of the raw metadata.
    is_unique: bool = False
    # Copied from the "valid" entry of the raw metadata.
    is_valid: bool = True
    # Raw index definition text, if provided — presumably pg_get_indexdef
    # output; verify against get_constraints().
    definition: str | None = None
 70
 71
@dataclass
class ConstraintState:
    """A constraint from pg_constraint.

    All constraint types use this single class, matching Postgres's
    pg_constraint catalog. FK-specific fields (target_table, target_column)
    are only populated for foreign key constraints.
    """

    # Which kind of constraint this is (pg_constraint.contype).
    constraint_type: ConType
    # Column names the constraint applies to.
    columns: list[str]
    # Copied from the "validated" entry of the raw metadata; True when absent.
    validated: bool = True
    # Raw constraint definition text, if provided — presumably
    # pg_get_constraintdef output; verify against get_constraints().
    definition: str | None = None
    target_table: str | None = None  # FK only
    target_column: str | None = None  # FK only
 87
 88
@dataclass
class TableState:
    """Raw database state for a single table.

    Mirrors Postgres's catalog structure:
    - columns from pg_attribute
    - indexes from pg_index + pg_am
    - constraints from pg_constraint (all types in one dict)

    ``introspect_table`` returns ``TableState(exists=False)`` with empty
    dicts when no columns are found for the table.
    """

    # False when the table was not found during introspection.
    exists: bool = True
    # Keyed by column name.
    columns: dict[str, ColumnState] = field(default_factory=dict)
    # Keyed by index name.
    indexes: dict[str, IndexState] = field(default_factory=dict)
    # Keyed by constraint name.
    constraints: dict[str, ConstraintState] = field(default_factory=dict)
103
104
def introspect_table(
    conn: DatabaseConnection, cursor: CursorWrapper, table_name: str
) -> TableState:
    """Query the database and return the raw state of a table.

    Args:
        conn: Connection whose ``get_constraints()`` supplies the raw
            constraint/index metadata for the table.
        cursor: Cursor used for both the column query and
            ``get_constraints()``.
        table_name: Name of the table to inspect.

    Returns:
        ``TableState(exists=False)`` when no columns are found for the table;
        otherwise a ``TableState`` holding its columns, indexes, and
        constraints. Entries that are neither a known constraint type nor
        flagged as an index are omitted.
    """
    actual_columns = _get_columns(cursor, table_name)
    if not actual_columns:
        return TableState(exists=False)

    raw = conn.get_constraints(cursor, table_name)

    indexes: dict[str, IndexState] = {}
    constraints: dict[str, ConstraintState] = {}

    for name, info in raw.items():
        raw_contype = info.get("contype")

        # Map raw contype to ConType enum if it's a known constraint type.
        contype: ConType | None = None
        if raw_contype:
            try:
                contype = ConType(raw_contype)
            except ValueError:
                # Deliberate: an unrecognised contype is not an error. The
                # entry may still be recorded as a plain index below.
                pass

        # Copy so the returned state never aliases lists owned by `raw`;
        # `or []` also covers a key that is present but set to None.
        columns = list(info.get("columns") or [])

        if contype in (
            ConType.PRIMARY,
            ConType.UNIQUE,
            ConType.CHECK,
            ConType.EXCLUSION,
        ):
            constraints[name] = ConstraintState(
                constraint_type=contype,
                columns=columns,
                validated=info.get("validated", True),
                definition=info.get("definition"),
            )
        elif contype == ConType.FOREIGN_KEY:
            # `or ()` guards against "foreign_key" being present but None,
            # which would otherwise make len() raise TypeError.
            fk_target = info.get("foreign_key") or ()
            # Only single-column FKs with a (table, column) target fit
            # ConstraintState; anything else is skipped.
            if len(columns) == 1 and len(fk_target) == 2:
                constraints[name] = ConstraintState(
                    constraint_type=ConType.FOREIGN_KEY,
                    columns=columns,
                    validated=info.get("validated", True),
                    definition=info.get("definition"),
                    target_table=fk_target[0],
                    target_column=fk_target[1],
                )
        elif info.get("index"):
            # get_constraints() encodes basic btree indexes as Index.suffix ("idx")
            # and non-btree indexes as their raw pg_am.amname. Reverse that here.
            raw_type = info.get("type", DEFAULT_INDEX_ACCESS_METHOD)
            access_method = (
                DEFAULT_INDEX_ACCESS_METHOD if raw_type == Index.suffix else raw_type
            )
            indexes[name] = IndexState(
                columns=columns,
                access_method=access_method,
                is_unique=info.get("unique", False),
                is_valid=info.get("valid", True),
                definition=info.get("definition"),
            )

    return TableState(
        exists=True,
        columns=actual_columns,
        indexes=indexes,
        constraints=constraints,
    )
174
175
def get_unknown_tables(conn: DatabaseConnection | None = None) -> list[str]:
    """List, in sorted order, the database tables that Plain does not own.

    A table counts as unknown when it is reported by ``conn.table_names()``
    but is neither in ``conn.plain_table_names()`` nor the migration
    bookkeeping table. When ``conn`` is omitted, ``get_connection()`` is used.
    """
    # Imported here rather than at module level, as in the original layout.
    from ..migrations.recorder import MIGRATION_TABLE_NAME

    connection = get_connection() if conn is None else conn

    present = set(connection.table_names())
    accounted_for = set(connection.plain_table_names())
    accounted_for.add(MIGRATION_TABLE_NAME)

    return sorted(present - accounted_for)
185
186
187def _strip_balanced_parens(s: str) -> str:
188    """Strip redundant outermost parentheses when they wrap the entire expression."""
189    while s.startswith("(") and s.endswith(")"):
190        inner = s[1:-1]
191        depth = 0
192        balanced = True
193        for ch in inner:
194            if ch == "(":
195                depth += 1
196            elif ch == ")":
197                depth -= 1
198            if depth < 0:
199                balanced = False
200                break
201        if balanced and depth == 0:
202            s = inner.strip()
203        else:
204            break
205    return s
206
207
def _normalize_sql(s: str) -> str:
    """Return a canonical single-line form of a SQL fragment.

    sqlparse is asked to lowercase keywords and identifiers and to strip
    whitespace; then every double-quote character is removed and each run of
    whitespace is collapsed to a single space.
    """
    formatted = sqlparse.format(
        s,
        strip_whitespace=True,
        identifier_case="lower",
        keyword_case="lower",
    )
    unquoted = formatted.replace('"', "")
    single_spaced = re.sub(r"\s+", " ", unquoted)
    return single_spaced.strip()
215
216
217def _strip_type_casts(s: str) -> str:
218    """Strip PostgreSQL type casts (e.g. ''::text, 0::integer).
219
220    PostgreSQL adds explicit casts to stored definitions (pg_get_indexdef,
221    pg_get_constraintdef) but the ORM compiler omits them.  Only used for
222    expression/condition comparison where the two generators diverge.
223    """
224    return re.sub(r"::\w+", "", s)
225
226
def normalize_check_definition(s: str) -> str:
    """Normalize a CHECK/condition definition for comparison.

    Strips the CHECK(...) wrapper, redundant parentheses, and PG type casts
    so that pg_get_constraintdef/pg_get_indexdef output and model-generated
    SQL can be compared.

    A leading ``check`` is treated as the CHECK keyword only when it is
    followed (after optional whitespace) by an opening parenthesis. Bare
    conditions that merely begin with those letters, such as
    ``checked_at IS NOT NULL``, are left intact.
    """
    s = _normalize_sql(s)
    s = _strip_type_casts(s)
    # Strip outer check(...). The lookahead keeps the "(" in place and stops
    # identifiers like "checked_at" from losing their first five characters.
    keyword = re.match(r"check\s*(?=\()", s)
    if keyword:
        s = s[keyword.end() :].strip()
        if s.startswith("(") and s.endswith(")"):
            s = s[1:-1].strip()
    s = _strip_balanced_parens(s)
    return s
243
244
def normalize_unique_definition(s: str) -> str:
    """Normalize a UNIQUE constraint definition for comparison.

    The text is passed through ``_normalize_sql`` and a leading ``unique``
    is dropped, so pg_get_constraintdef output and model-generated
    definitions can be compared. Everything after that prefix, including any
    INCLUDE or DEFERRABLE clause, is kept as part of the result.
    """
    normalized = _normalize_sql(s)
    # `normalized` is already stripped, so the trailing strip() only matters
    # when the prefix was actually removed.
    return normalized.removeprefix("unique").strip()
256
257
def normalize_expression(s: str) -> str:
    """Normalize an index expression for comparison.

    Applies ``_normalize_sql`` (lowercasing, quote removal, whitespace
    collapsing) and then drops redundant outer parentheses, so that e.g.
    'LOWER("col")' and 'lower(col)' compare equal.
    """
    canonical = _normalize_sql(s)
    return _strip_balanced_parens(canonical)
266
267
def normalize_index_definition(s: str) -> str:
    """Extract and normalize the expression part of a CREATE INDEX definition.

    Drops the CREATE INDEX ... ON table [USING method] prefix so that
    pg_get_indexdef output and model-generated SQL can be compared, then
    removes outer parentheses that wrap the entire remainder.

    Example: 'CREATE INDEX foo ON bar USING btree (upper(email))'
           → 'upper(email)'
    """
    normalized = _normalize_sql(s)

    if (using := re.search(r"\busing \w+ ", normalized)) is not None:
        # Everything after "using <method> " is the expression spec.
        spec = normalized[using.end() :]
    else:
        # No USING clause: start at the first "(" that follows " on ".
        spec = normalized
        on_at = normalized.find(" on ")
        paren_at = normalized.find("(", on_at) if on_at >= 0 else -1
        if paren_at >= 0:
            spec = normalized[paren_at:]

    # The model may generate ((UPPER(col))) while the DB has (upper(col));
    # removing wrapping parens makes both forms identical.
    return _strip_balanced_parens(spec)
293
294
def _get_columns(cursor: CursorWrapper, table_name: str) -> dict[str, ColumnState]:
    """Read the table's columns from pg_attribute.

    Only a relation named ``table_name`` that passes ``pg_table_is_visible``
    is considered; system columns (attnum <= 0) and dropped columns are
    excluded. The result maps column name to ``ColumnState`` in attnum order
    and is empty when no matching rows are found.
    """
    cursor.execute(
        """
        SELECT a.attname, format_type(a.atttypid, a.atttypmod), a.attnotnull
        FROM pg_attribute a
        JOIN pg_class c ON a.attrelid = c.oid
        WHERE c.relname = %s AND pg_catalog.pg_table_is_visible(c.oid)
          AND a.attnum > 0 AND NOT a.attisdropped
        ORDER BY a.attnum
        """,
        [table_name],
    )

    found: dict[str, ColumnState] = {}
    for attname, formatted_type, attnotnull in cursor.fetchall():
        found[attname] = ColumnState(type=formatted_type, not_null=attnotnull)
    return found