1from __future__ import annotations
2
3import re
4from dataclasses import dataclass, field
5from enum import StrEnum
6from functools import cached_property
7from typing import TYPE_CHECKING, Any
8
9from ..constraints import CheckConstraint, UniqueConstraint
10from ..ddl import compile_expression_sql, compile_index_expressions_sql
11from ..dialect import quote_name
12from ..fields.related import ForeignKeyField
13from ..indexes import Index
14from ..introspection import (
15 MANAGED_CONSTRAINT_TYPES,
16 MANAGED_INDEX_ACCESS_METHODS,
17 ConstraintState,
18 ConType,
19 TableState,
20 introspect_table,
21 normalize_check_definition,
22 normalize_expression,
23 normalize_index_definition,
24 normalize_unique_definition,
25)
26
27if TYPE_CHECKING:
28 from ..base import Model
29 from ..connection import DatabaseConnection
30 from ..expressions import Expression, ReplaceableExpression
31 from ..query_utils import Q
32 from ..utils import CursorWrapper
33
34
35# Drift types — semantic descriptions of schema differences
36
37
class DriftKind(StrEnum):
    """Classification of a single schema difference between model and DB."""

    MISSING = "missing"  # declared in the model, absent from the database
    INVALID = "invalid"  # present in the DB but flagged invalid (needs rebuild)
    CHANGED = "changed"  # same name exists, but the definition differs
    RENAMED = "renamed"  # equivalent definition exists under another name
    UNDECLARED = "undeclared"  # present in the DB, not declared in the model
    UNVALIDATED = "unvalidated"  # constraint exists but is NOT VALID
45
46
@dataclass
class IndexDrift:
    """A schema difference for an index."""

    kind: DriftKind
    table: str
    index: Index | None = None
    model: type[Model] | None = None
    old_name: str | None = None
    new_name: str | None = None
    name: str | None = None

    def describe(self) -> str:
        """Return a one-line human-readable summary of this drift."""
        if self.kind is DriftKind.RENAMED:
            return f"{self.table}: index {self.old_name} → {self.new_name}"
        if self.kind in (DriftKind.MISSING, DriftKind.INVALID, DriftKind.CHANGED):
            # These kinds always carry the model-side Index object.
            assert self.index is not None
            detail = {
                DriftKind.MISSING: "missing",
                DriftKind.INVALID: "INVALID",
                DriftKind.CHANGED: "definition changed",
            }[self.kind]
            return f"{self.table}: index {self.index.name} {detail}"
        # Remaining kind in practice: UNDECLARED (index only exists in the DB).
        return f"{self.table}: index {self.name} not declared"
74
75
@dataclass
class ConstraintDrift:
    """A schema difference for a constraint."""

    kind: DriftKind
    table: str
    constraint: CheckConstraint | UniqueConstraint | None = None
    model: type[Model] | None = None
    old_name: str | None = None
    new_name: str | None = None
    name: str | None = None

    def describe(self) -> str:
        """Return a one-line human-readable summary of this drift."""
        if self.kind is DriftKind.MISSING:
            assert self.constraint is not None
            return f"{self.table}: constraint {self.constraint.name} missing"
        if self.kind is DriftKind.UNVALIDATED:
            return f"{self.table}: constraint {self.name} NOT VALID"
        if self.kind is DriftKind.CHANGED:
            assert self.constraint is not None
            return f"{self.table}: constraint {self.constraint.name} definition changed"
        if self.kind is DriftKind.RENAMED:
            return f"{self.table}: constraint {self.old_name} → {self.new_name}"
        # Remaining kind in practice: UNDECLARED (only exists in the DB).
        return f"{self.table}: constraint {self.name} not declared"
102
103
@dataclass
class ForeignKeyDrift:
    """A schema difference for a foreign key constraint."""

    kind: DriftKind
    table: str
    name: str | None = None
    column: str | None = None
    target_table: str | None = None
    target_column: str | None = None

    def describe(self) -> str:
        """Return a one-line human-readable summary of this drift."""
        if self.kind is DriftKind.MISSING:
            shape = f"{self.column} → {self.target_table}.{self.target_column}"
            return f"{self.table}: FK {self.name} missing ({shape})"
        if self.kind is DriftKind.UNVALIDATED:
            return f"{self.table}: FK {self.name} NOT VALID"
        # Remaining kind in practice: UNDECLARED (only exists in the DB).
        return f"{self.table}: FK {self.name} not declared"
123
124
@dataclass
class NullabilityDrift:
    """Mismatch between model and DB column nullability."""

    table: str
    column: str
    model_allows_null: bool
    has_null_rows: bool = False  # Only checked when model_allows_null is False

    def describe(self) -> str:
        """Return a one-line human-readable summary of this drift."""
        prefix = f"{self.table}: column {self.column}"
        if self.model_allows_null:
            # DB is stricter than the model.
            return f"{prefix} is NOT NULL, model allows NULL"
        if self.has_null_rows:
            # Data would block an immediate SET NOT NULL.
            return f"{prefix} allows NULL (NULL rows exist)"
        return f"{prefix} allows NULL"
142
143
144type Drift = IndexDrift | ConstraintDrift | ForeignKeyDrift | NullabilityDrift
145
146
147# Status objects — analysis results with optional drift
148
149
@dataclass
class ColumnStatus:
    """Per-column comparison result for one model field or extra DB column."""

    name: str  # database column name
    field_name: str  # model field name; "" for DB columns not in the model
    type: str  # field's db_type(), or the actual DB type for extra columns
    nullable: bool
    primary_key: bool
    pk_suffix: str  # db_type_suffix() for primary-key columns; "" otherwise
    issue: str | None = None  # human-readable problem; None when in sync
    drift: NullabilityDrift | None = None  # structured drift, when actionable
160
161
@dataclass
class IndexStatus:
    """Per-index comparison result."""

    name: str
    fields: list[str] = field(default_factory=list)  # model fields or DB columns
    issue: str | None = None  # human-readable problem; None when in sync
    drift: IndexDrift | None = None  # structured drift, when actionable
    access_method: str | None = None  # set for unmanaged indexes (display only)
169
170
@dataclass
class ConstraintStatus:
    """Per-constraint comparison result (unique, check, FK, or unmanaged)."""

    name: str
    constraint_type: ConType
    fields: list[str] = field(default_factory=list)
    issue: str | None = None  # human-readable problem; None when in sync
    # IndexDrift appears here for index-only uniques, which live as indexes.
    drift: ConstraintDrift | ForeignKeyDrift | IndexDrift | None = None
178
179
@dataclass
class ModelAnalysis:
    """Comparison result for one model: per-object statuses plus rollups."""

    label: str
    table: str
    table_issues: list[str] = field(default_factory=list)
    columns: list[ColumnStatus] = field(default_factory=list)
    indexes: list[IndexStatus] = field(default_factory=list)
    constraints: list[ConstraintStatus] = field(default_factory=list)

    @cached_property
    def drifts(self) -> list[Drift]:
        """All detected schema drifts."""
        entries = (*self.columns, *self.indexes, *self.constraints)
        return [entry.drift for entry in entries if entry.drift]

    @cached_property
    def issue_count(self) -> int:
        """Total issues (table + columns + indexes + constraints)."""
        entries = (*self.columns, *self.indexes, *self.constraints)
        return len(self.table_issues) + sum(1 for entry in entries if entry.issue)

    def to_dict(self) -> dict[str, Any]:
        """Serialize for --json output."""

        def describe(drift: Drift | None) -> str | None:
            # Drifts serialize as their human-readable description.
            return drift.describe() if drift else None

        def column_entry(col: ColumnStatus) -> dict[str, Any]:
            return {
                "name": col.name,
                "field_name": col.field_name,
                "type": col.type,
                "nullable": col.nullable,
                "primary_key": col.primary_key,
                "pk_suffix": col.pk_suffix,
                "issue": col.issue,
                "drift": describe(col.drift),
            }

        def index_entry(idx: IndexStatus) -> dict[str, Any]:
            return {
                "name": idx.name,
                "fields": idx.fields,
                "access_method": idx.access_method,
                "issue": idx.issue,
                "drift": describe(idx.drift),
            }

        def constraint_entry(con: ConstraintStatus) -> dict[str, Any]:
            return {
                "name": con.name,
                "constraint_type": con.constraint_type,
                "type_label": con.constraint_type.label,
                "fields": con.fields,
                "issue": con.issue,
                "drift": describe(con.drift),
            }

        return {
            "label": self.label,
            "table": self.table,
            "table_issues": self.table_issues,
            "columns": [column_entry(col) for col in self.columns],
            "indexes": [index_entry(idx) for idx in self.indexes],
            "constraints": [constraint_entry(con) for con in self.constraints],
        }
254
255
def analyze_model(
    conn: DatabaseConnection, cursor: CursorWrapper, model: type[Model]
) -> ModelAnalysis:
    """Compare a model against the database and classify each difference.

    Introspects the actual table state, compares it against model definitions,
    and produces a ModelAnalysis where each column/index/constraint carries its
    issue (if any) and drift object (if schema differs).
    """
    options = model.model_options
    state = introspect_table(conn, cursor, options.db_table)

    # A missing table short-circuits all per-object comparison.
    if not state.exists:
        return ModelAnalysis(
            label=options.label,
            table=options.db_table,
            table_issues=["table missing from database"],
        )

    return ModelAnalysis(
        label=options.label,
        table=options.db_table,
        columns=_compare_columns(model, state, options.db_table, cursor),
        indexes=_compare_indexes(model, state, options.db_table),
        constraints=_compare_constraints(model, state, options.db_table),
    )
282
283
284# Column comparison
285
286
def _column_has_nulls(cursor: CursorWrapper, table: str, column: str) -> bool:
    """Check whether any NULL values exist in a column."""
    # LIMIT 1 keeps this a cheap existence probe; identifiers are quoted,
    # not parameterized, since placeholders can't bind identifiers.
    sql = (
        f"SELECT 1 FROM {quote_name(table)} "
        f"WHERE {quote_name(column)} IS NULL LIMIT 1"
    )
    cursor.execute(sql)
    return cursor.fetchone() is not None
293
294
def _compare_columns(
    model: type[Model], db: TableState, table: str, cursor: CursorWrapper
) -> list[ColumnStatus]:
    """Compare model fields against introspected table columns.

    Checks presence, DB type, and nullability for each concrete field, then
    reports DB columns that no model field accounts for. Returns one
    ColumnStatus per field plus one per extra DB column.
    """
    statuses: list[ColumnStatus] = []
    expected_col_names: set[str] = set()

    for f in model._model_meta.local_fields:
        db_type = f.db_type()
        if db_type is None:
            # Field has no column of its own — nothing to compare.
            continue

        expected_col_names.add(f.column)
        issue: str | None = None
        drift: NullabilityDrift | None = None

        if f.column not in db.columns:
            issue = "missing from database"
        else:
            actual = db.columns[f.column]
            # Type mismatch takes precedence; nullability only checked on match.
            if db_type != actual.type:
                issue = f"expected {db_type}, actual {actual.type}"
            elif not f.allow_null and not actual.not_null:
                # Model says NOT NULL, DB allows NULL — semantic drift
                # Whether NULL rows exist determines how a fix must proceed.
                has_nulls = _column_has_nulls(cursor, table, f.column)
                if has_nulls:
                    issue = "expected NOT NULL, actual NULL (NULL rows exist)"
                else:
                    issue = "expected NOT NULL, actual NULL"
                drift = NullabilityDrift(
                    table=table,
                    column=f.column,
                    model_allows_null=False,
                    has_null_rows=has_nulls,
                )
            elif f.allow_null and actual.not_null:
                # Opposite direction: DB is stricter than the model.
                issue = "expected NULL, actual NOT NULL"
                drift = NullabilityDrift(
                    table=table,
                    column=f.column,
                    model_allows_null=True,
                )

        # Primary-key columns may carry an extra type suffix for display.
        pk_suffix = ""
        if f.primary_key:
            pk_suffix = f.db_type_suffix() or ""

        assert f.name is not None
        statuses.append(
            ColumnStatus(
                name=f.column,
                field_name=f.name,
                type=db_type,
                nullable=f.allow_null,
                primary_key=f.primary_key,
                pk_suffix=pk_suffix,
                issue=issue,
                drift=drift,
            )
        )

    # DB columns not produced by any model field; sorted for stable output.
    for col_name in sorted(db.columns.keys() - expected_col_names):
        actual = db.columns[col_name]
        statuses.append(
            ColumnStatus(
                name=col_name,
                field_name="",
                type=actual.type,
                nullable=not actual.not_null,
                primary_key=False,
                pk_suffix="",
                issue="extra column, not in model",
            )
        )

    return statuses
370
371
372# Index comparison with rename detection
373
374
def _compare_indexes(
    model: type[Model], db: TableState, table: str
) -> list[IndexStatus]:
    """Compare model indexes against the table's non-unique DB indexes.

    Handles, in order: name collisions with unmanaged index types, INVALID
    indexes, definition changes, rename detection via normalized definitions,
    then missing / undeclared leftovers. Unmanaged extra indexes are listed
    informationally with no drift.
    """
    statuses: list[IndexStatus] = []
    missing: list[Index] = []
    model_index_names = {idx.name for idx in model.model_options.indexes}
    # Unique indexes are handled by _compare_unique_constraints, not here.
    # Also exclude indexes that back unique constraints in pg_constraint.
    unique_constraint_names = {
        k for k, v in db.constraints.items() if v.constraint_type == ConType.UNIQUE
    }
    non_unique_indexes = {
        k: v
        for k, v in db.indexes.items()
        if not v.is_unique and k not in unique_constraint_names
    }

    for index in model.model_options.indexes:
        if index.name not in non_unique_indexes:
            # Deferred: may turn out to be a rename rather than truly missing.
            missing.append(index)
            continue

        db_idx = non_unique_indexes[index.name]

        # Name collision: DB has an unmanaged index type with this name
        if db_idx.access_method not in MANAGED_INDEX_ACCESS_METHODS:
            statuses.append(
                IndexStatus(
                    name=index.name,
                    fields=list(index.fields) if index.fields else [],
                    issue=f"name conflict with {db_idx.access_method} index — rename one to resolve",
                )
            )
            continue

        # An index flagged invalid in the DB can only be fixed by rebuild.
        if not db_idx.is_valid:
            statuses.append(
                IndexStatus(
                    name=index.name,
                    fields=list(index.fields) if index.fields else [],
                    issue="INVALID — needs drop and recreate",
                    drift=IndexDrift(
                        kind=DriftKind.INVALID,
                        table=table,
                        index=index,
                        model=model,
                    ),
                )
            )
            continue

        # Check if definition matches
        if db_idx.definition:
            issue = _compare_index_definition(model, index, db_idx.definition)
            if issue:
                statuses.append(
                    IndexStatus(
                        name=index.name,
                        fields=list(index.fields) if index.fields else [],
                        issue=issue,
                        drift=IndexDrift(
                            kind=DriftKind.CHANGED,
                            table=table,
                            index=index,
                            model=model,
                        ),
                    )
                )
                continue

        # Index exists and matches
        statuses.append(
            IndexStatus(
                name=index.name,
                fields=list(index.fields) if index.fields else [],
            )
        )

    # Extra indexes (in DB but not in model)
    extra_names = sorted(non_unique_indexes.keys() - model_index_names)

    # Only managed index types participate in rename detection
    managed_extra = [
        n
        for n in extra_names
        if non_unique_indexes[n].access_method in MANAGED_INDEX_ACCESS_METHODS
    ]

    # Detect renames: match missing and extra by normalized definition
    renamed_missing: set[str] = set()
    renamed_extra: set[str] = set()

    missing_by_def: dict[str, list[Index]] = {}
    for index in missing:
        norm = normalize_index_definition(index.to_sql(model))
        missing_by_def.setdefault(norm, []).append(index)

    extra_by_def: dict[str, list[str]] = {}
    for name in managed_extra:
        defn = non_unique_indexes[name].definition
        if defn:
            norm = normalize_index_definition(defn)
            extra_by_def.setdefault(norm, []).append(name)

    # Only unambiguous 1:1 matches count as renames; anything else falls
    # through as separate missing + undeclared entries below.
    for norm, m_list in missing_by_def.items():
        e_list = extra_by_def.get(norm)
        if e_list and len(m_list) == 1 and len(e_list) == 1:
            index = m_list[0]
            old_name = e_list[0]
            statuses.append(
                IndexStatus(
                    name=index.name,
                    fields=list(index.fields) if index.fields else [],
                    issue=f"rename from {old_name}",
                    drift=IndexDrift(
                        kind=DriftKind.RENAMED,
                        table=table,
                        old_name=old_name,
                        new_name=index.name,
                    ),
                )
            )
            renamed_missing.add(index.name)
            renamed_extra.add(old_name)

    # Remaining unmatched missing
    for index in missing:
        if index.name not in renamed_missing:
            statuses.append(
                IndexStatus(
                    name=index.name,
                    fields=list(index.fields) if index.fields else [],
                    issue="missing from database",
                    drift=IndexDrift(
                        kind=DriftKind.MISSING,
                        table=table,
                        index=index,
                        model=model,
                    ),
                )
            )

    # Extra managed indexes are undeclared
    for name in managed_extra:
        if name not in renamed_extra:
            statuses.append(
                IndexStatus(
                    name=name,
                    fields=non_unique_indexes[name].columns,
                    issue="not in model",
                    drift=IndexDrift(
                        kind=DriftKind.UNDECLARED,
                        table=table,
                        name=name,
                    ),
                )
            )

    # Extra unmanaged indexes — informational only, no drift
    for name in extra_names:
        idx = non_unique_indexes[name]
        if idx.access_method not in MANAGED_INDEX_ACCESS_METHODS:
            statuses.append(
                IndexStatus(
                    name=name,
                    fields=idx.columns,
                    access_method=idx.access_method,
                )
            )

    return statuses
546
547
def _compare_index_definition(
    model: type[Model], index: Index, actual_def: str
) -> str | None:
    """Compare a model index against its pg_get_indexdef output.

    Returns an issue string if definitions differ, None if they match.
    """
    field_names = [field_name for field_name, _ in index.fields_orders]
    opclasses = list(index.opclasses) if index.opclasses else []
    return _compare_parsed_index(
        model=model,
        expressions=index.expressions,
        fields=field_names,
        opclasses=opclasses,
        condition=index.condition,
        actual_def=actual_def,
    )
563
564
565# Constraint comparison
566
567
def _compare_constraints(
    model: type[Model], db: TableState, table: str
) -> list[ConstraintStatus]:
    """Aggregate unique, check, and FK comparisons plus unmanaged extras."""
    statuses = [
        *_compare_unique_constraints(model, db, table),
        *_compare_check_constraints(model, db, table),
        *_compare_foreign_keys(model, db, table),
    ]

    # Unmanaged constraint types — informational only, no drift.
    # Primary keys are also unmanaged but not shown.
    for name, cs in db.constraints.items():
        ctype = cs.constraint_type
        if ctype in MANAGED_CONSTRAINT_TYPES or ctype == ConType.PRIMARY:
            continue
        statuses.append(
            ConstraintStatus(
                name=name,
                constraint_type=ctype,
                fields=cs.columns,
            )
        )

    return statuses
592
593
def _compare_unique_constraints(
    model: type[Model], db: TableState, table: str
) -> list[ConstraintStatus]:
    """Compare model UniqueConstraints against DB unique state.

    Unique state is gathered from two sources — pg_constraint rows and
    standalone unique indexes — then matched by name, with rename detection
    and missing/undeclared reporting for the leftovers.
    """
    statuses: list[ConstraintStatus] = []
    # Unique constraints from pg_constraint (contype='u')
    actual_constraints = {
        k: v for k, v in db.constraints.items() if v.constraint_type == ConType.UNIQUE
    }
    # Unique indexes from pg_index that don't have a backing pg_constraint
    # (e.g. partial/expression unique indexes created with CREATE UNIQUE INDEX)
    actual_indexes = {
        k: ConstraintState(
            constraint_type=ConType.UNIQUE,
            columns=v.columns,
            validated=True,
            definition=v.definition,
        )
        for k, v in db.indexes.items()
        if v.is_unique and k not in actual_constraints
    }
    actual = {**actual_constraints, **actual_indexes}
    model_constraints = [
        c for c in model.model_options.constraints if isinstance(c, UniqueConstraint)
    ]
    expected_names = {c.name for c in model_constraints}
    extra_names = sorted(actual.keys() - expected_names)

    missing: list[UniqueConstraint] = []
    for constraint in model_constraints:
        if constraint.name not in actual:
            # Deferred: may turn out to be a rename rather than truly missing.
            missing.append(constraint)
            continue

        issue: str | None = None
        drift: ConstraintDrift | None = None

        if not actual[constraint.name].validated:
            issue = "NOT VALID — needs validation"
            drift = ConstraintDrift(
                kind=DriftKind.UNVALIDATED,
                table=table,
                name=constraint.name,
            )
        elif constraint.index_only:
            # Index-only uniques need structured index-definition comparison.
            issue, drift = _compare_index_only_unique(
                model, constraint, actual[constraint.name], table
            )
        elif actual_def := actual[constraint.name].definition:
            expected_def = _get_expected_unique_definition(model, constraint)
            if normalize_unique_definition(actual_def) != normalize_unique_definition(
                expected_def
            ):
                issue = f"definition differs: DB has {actual_def!r}, model expects {expected_def!r}"
                drift = ConstraintDrift(
                    kind=DriftKind.CHANGED,
                    table=table,
                    constraint=constraint,
                    model=model,
                )

        statuses.append(
            ConstraintStatus(
                name=constraint.name,
                constraint_type=ConType.UNIQUE,
                fields=list(constraint.fields),
                issue=issue,
                drift=drift,
            )
        )

    # Detect renames by columns
    rename_statuses, renamed_missing, renamed_extra = _detect_unique_renames(
        missing, extra_names, actual, model, table
    )
    statuses.extend(rename_statuses)

    for constraint in missing:
        if constraint.name not in renamed_missing:
            statuses.append(
                ConstraintStatus(
                    name=constraint.name,
                    constraint_type=ConType.UNIQUE,
                    fields=list(constraint.fields),
                    issue="missing from database",
                    drift=ConstraintDrift(
                        kind=DriftKind.MISSING,
                        table=table,
                        constraint=constraint,
                        model=model,
                    ),
                )
            )

    for name in extra_names:
        if name not in renamed_extra:
            # Index-only entries (from pg_index, not pg_constraint) need
            # IndexDrift so the planner uses DROP INDEX, not DROP CONSTRAINT.
            undeclared_drift: Drift
            if name in actual_indexes:
                undeclared_drift = IndexDrift(
                    kind=DriftKind.UNDECLARED, table=table, name=name
                )
            else:
                undeclared_drift = ConstraintDrift(
                    kind=DriftKind.UNDECLARED, table=table, name=name
                )
            statuses.append(
                ConstraintStatus(
                    name=name,
                    constraint_type=ConType.UNIQUE,
                    fields=actual[name].columns,
                    issue="not in model",
                    drift=undeclared_drift,
                )
            )

    return statuses
711
712
def _compare_check_constraints(
    model: type[Model], db: TableState, table: str
) -> list[ConstraintStatus]:
    """Compare model CheckConstraints against DB check constraints.

    Matches by name, flags NOT VALID and changed definitions, detects
    renames by normalized definition, and reports missing/undeclared
    leftovers — except framework-internal temp NOT NULL checks.
    """
    statuses: list[ConstraintStatus] = []
    actual = {
        k: v for k, v in db.constraints.items() if v.constraint_type == ConType.CHECK
    }
    model_constraints = [
        c for c in model.model_options.constraints if isinstance(c, CheckConstraint)
    ]
    expected_names = {c.name for c in model_constraints}
    extra_names = sorted(actual.keys() - expected_names)

    missing: list[CheckConstraint] = []
    for constraint in model_constraints:
        if constraint.name not in actual:
            # Deferred: may turn out to be a rename rather than truly missing.
            missing.append(constraint)
            continue

        issue: str | None = None
        drift: ConstraintDrift | None = None

        if not actual[constraint.name].validated:
            issue = "NOT VALID — needs validation"
            drift = ConstraintDrift(
                kind=DriftKind.UNVALIDATED,
                table=table,
                name=constraint.name,
            )
        elif actual_def := actual[constraint.name].definition:
            # Compare normalized forms to tolerate formatting differences.
            expected_def = _get_expected_check_definition(model, constraint)
            if normalize_check_definition(actual_def) != normalize_check_definition(
                expected_def
            ):
                issue = f"definition differs: DB has {actual_def!r}, model expects {expected_def!r}"
                drift = ConstraintDrift(
                    kind=DriftKind.CHANGED,
                    table=table,
                    constraint=constraint,
                    model=model,
                )

        statuses.append(
            ConstraintStatus(
                name=constraint.name,
                constraint_type=ConType.CHECK,
                fields=[],
                issue=issue,
                drift=drift,
            )
        )

    # Detect renames by definition
    rename_statuses, renamed_missing, renamed_extra = _detect_check_renames(
        missing, extra_names, actual, model, table
    )
    statuses.extend(rename_statuses)

    for constraint in missing:
        if constraint.name not in renamed_missing:
            statuses.append(
                ConstraintStatus(
                    name=constraint.name,
                    constraint_type=ConType.CHECK,
                    fields=[],
                    issue="missing from database",
                    drift=ConstraintDrift(
                        kind=DriftKind.MISSING,
                        table=table,
                        constraint=constraint,
                        model=model,
                    ),
                )
            )

    # Build set of framework-owned temp NOT NULL check names so leftover
    # artifacts from a partially-completed SetNotNullFix are silently
    # ignored rather than surfaced as undeclared user constraints.
    internal_checks = {
        generate_notnull_check_name(table, f.column)
        for f in model._model_meta.local_fields
        if f.db_type() is not None
    }

    for name in extra_names:
        if name not in renamed_extra and name not in internal_checks:
            statuses.append(
                ConstraintStatus(
                    name=name,
                    constraint_type=ConType.CHECK,
                    fields=actual[name].columns,
                    issue="not in model",
                    drift=ConstraintDrift(
                        kind=DriftKind.UNDECLARED,
                        table=table,
                        name=name,
                    ),
                )
            )

    return statuses
814
815
def _compare_foreign_keys(
    model: type[Model], db: TableState, table: str
) -> list[ConstraintStatus]:
    """Compare model FK fields against DB foreign-key constraints.

    FKs are matched by shape — (column, target table, target column) —
    rather than by name, so a differently-named but equivalent DB FK still
    counts as matched. Only single-column FKs are considered (columns[0]).
    """
    statuses: list[ConstraintStatus] = []
    actual = {
        k: v
        for k, v in db.constraints.items()
        if v.constraint_type == ConType.FOREIGN_KEY
    }

    # Build expected FKs from model fields: shape → (field_name, constraint_name)
    expected_fks: dict[tuple[str, str, str], tuple[str, str]] = {}
    for f in model._model_meta.local_fields:
        # db_constraint=False fields intentionally have no DB-level FK.
        if isinstance(f, ForeignKeyField) and f.db_constraint:
            assert f.name is not None
            to_table = f.target_field.model.model_options.db_table
            to_column = f.target_field.column
            constraint_name = generate_fk_constraint_name(
                table, f.column, to_table, to_column
            )
            expected_fks[(f.column, to_table, to_column)] = (f.name, constraint_name)

    # Build actual FKs from DB: shape → (constraint_name, ConstraintState)
    actual_fk_by_shape: dict[tuple[str, str, str], tuple[str, ConstraintState]] = {}
    for name, cs in actual.items():
        if cs.target_table and cs.target_column and cs.columns:
            actual_fk_by_shape[(cs.columns[0], cs.target_table, cs.target_column)] = (
                name,
                cs,
            )

    matched_fk_names: set[str] = set()
    for key, (field_name, constraint_name) in expected_fks.items():
        if match := actual_fk_by_shape.get(key):
            actual_name, cs = match
            matched_fk_names.add(actual_name)

            # Check validation state
            issue: str | None = None
            drift: ForeignKeyDrift | None = None
            if not cs.validated:
                issue = "NOT VALID — needs validation"
                drift = ForeignKeyDrift(
                    kind=DriftKind.UNVALIDATED,
                    table=table,
                    name=actual_name,
                )

            statuses.append(
                ConstraintStatus(
                    name=actual_name,
                    constraint_type=ConType.FOREIGN_KEY,
                    fields=[key[0]],
                    issue=issue,
                    drift=drift,
                )
            )
        else:
            # Missing FK: display under its field arrow, but record the
            # deterministic constraint name the fix would create.
            col, to_table, to_column = key
            statuses.append(
                ConstraintStatus(
                    name=f"{field_name} → {to_table}.{to_column}",
                    constraint_type=ConType.FOREIGN_KEY,
                    fields=[col],
                    issue="missing from database",
                    drift=ForeignKeyDrift(
                        kind=DriftKind.MISSING,
                        table=table,
                        name=constraint_name,
                        column=col,
                        target_table=to_table,
                        target_column=to_column,
                    ),
                )
            )

    # DB FKs no model field accounts for; sorted for stable output.
    for name in sorted(actual.keys() - matched_fk_names):
        cs = actual[name]
        statuses.append(
            ConstraintStatus(
                name=name,
                constraint_type=ConType.FOREIGN_KEY,
                fields=cs.columns,
                issue=f"not in model (→ {cs.target_table}.{cs.target_column})",
                drift=ForeignKeyDrift(
                    kind=DriftKind.UNDECLARED,
                    table=table,
                    name=name,
                ),
            )
        )

    return statuses
909
910
def generate_notnull_check_name(table: str, column: str) -> str:
    """Generate a hashed name for the temporary NOT NULL check constraint.

    Used by SetNotNullFix for the CHECK NOT VALID → VALIDATE → SET NOT NULL
    pattern, and by analysis to recognize (and ignore) leftover temp checks.
    """
    # Local import mirrors the sibling helpers and avoids cycles at import time.
    from ..utils import generate_identifier_name

    columns = [column]
    return generate_identifier_name(table, columns, "_notnull")
920
921
def generate_fk_constraint_name(
    table: str, column: str, target_table: str, target_column: str
) -> str:
    """Generate a deterministic FK constraint name.

    Uses the same naming algorithm as the schema editor so that
    convergence-created FKs match migration-created ones.
    """
    # Local import mirrors the sibling helpers and avoids cycles at import time.
    from ..utils import generate_identifier_name, split_identifier

    # Strip any schema qualifier from the target table for the suffix.
    target_table_name = split_identifier(target_table)[1]
    return generate_identifier_name(
        table, [column], f"_fk_{target_table_name}_{target_column}"
    )
935
936
def _detect_unique_renames(
    missing: list[UniqueConstraint],
    extra_names: list[str],
    actual_dict: dict[str, ConstraintState],
    model: type[Model],
    table: str,
) -> tuple[list[ConstraintStatus], set[str], set[str]]:
    """Match missing and extra unique constraints by structure.

    Constraint-backed (not index_only): matched by resolved column tuple.
    Index-only (condition/expression/opclass): matched by normalized index
    definition, which captures the full semantics including WHERE clauses,
    opclasses, and expressions.

    Returns (rename statuses, renamed model names, renamed DB names) so the
    caller can exclude matched entries from missing/undeclared reporting.
    """
    statuses: list[ConstraintStatus] = []
    renamed_missing: set[str] = set()
    renamed_extra: set[str] = set()

    # Phase 1: Field-based — match by resolved column tuple.
    # Covers both constraint-backed and index-only field-based constraints.
    missing_by_cols: dict[tuple[str, ...], list[UniqueConstraint]] = {}
    for constraint in missing:
        if not constraint.fields:
            # Expression-based constraints are handled in Phase 2.
            continue
        cols = tuple(
            model._model_meta.get_forward_field(field_name).column
            for field_name in constraint.fields
        )
        missing_by_cols.setdefault(cols, []).append(constraint)

    extra_by_cols: dict[tuple[str, ...], list[str]] = {}
    for name in extra_names:
        cols = tuple(actual_dict[name].columns)
        if cols:
            extra_by_cols.setdefault(cols, []).append(name)

    # Only unambiguous 1:1 matches count as renames.
    for cols, m_list in missing_by_cols.items():
        e_list = extra_by_cols.get(cols)
        if e_list and len(m_list) == 1 and len(e_list) == 1:
            constraint = m_list[0]
            old_name = e_list[0]
            # For index-only uniques, same columns isn't enough — the
            # condition or opclass may have changed. Verify the full
            # definition matches before accepting a rename, otherwise
            # let both sides fall through as separate missing + undeclared.
            if constraint.index_only:
                old_def = actual_dict[old_name].definition
                if not old_def or normalize_index_definition(
                    old_def
                ) != normalize_index_definition(constraint.to_sql(model)):
                    continue
            # Index-only uniques live as indexes, so their rename needs
            # IndexDrift; constraint-backed ones get ConstraintDrift.
            DriftType = IndexDrift if constraint.index_only else ConstraintDrift
            statuses.append(
                ConstraintStatus(
                    name=constraint.name,
                    constraint_type=ConType.UNIQUE,
                    fields=list(constraint.fields),
                    issue=f"rename from {old_name}",
                    drift=DriftType(
                        kind=DriftKind.RENAMED,
                        table=table,
                        old_name=old_name,
                        new_name=constraint.name,
                    ),
                )
            )
            renamed_missing.add(constraint.name)
            renamed_extra.add(old_name)

    # Phase 2: Expression-based — match by normalized index definition.
    missing_by_def: dict[str, list[UniqueConstraint]] = {}
    for constraint in missing:
        if constraint.fields or constraint.name in renamed_missing:
            continue
        norm = normalize_index_definition(constraint.to_sql(model))
        missing_by_def.setdefault(norm, []).append(constraint)

    extra_by_def: dict[str, list[str]] = {}
    for name in extra_names:
        if name in renamed_extra:
            continue
        defn = actual_dict[name].definition
        if defn:
            norm = normalize_index_definition(defn)
            extra_by_def.setdefault(norm, []).append(name)

    for norm, m_list in missing_by_def.items():
        e_list = extra_by_def.get(norm)
        if e_list and len(m_list) == 1 and len(e_list) == 1:
            constraint = m_list[0]
            old_name = e_list[0]
            # Index-only uniques live as indexes, not constraints, so
            # emit IndexDrift so the planner uses ALTER INDEX RENAME.
            statuses.append(
                ConstraintStatus(
                    name=constraint.name,
                    constraint_type=ConType.UNIQUE,
                    fields=list(constraint.fields),
                    issue=f"rename from {old_name}",
                    drift=IndexDrift(
                        kind=DriftKind.RENAMED,
                        table=table,
                        old_name=old_name,
                        new_name=constraint.name,
                    ),
                )
            )
            renamed_missing.add(constraint.name)
            renamed_extra.add(old_name)

    return statuses, renamed_missing, renamed_extra
1048
1049
def _detect_check_renames(
    missing: list[CheckConstraint],
    extra_names: list[str],
    actual_dict: dict[str, ConstraintState],
    model: type[Model],
    table: str,
) -> tuple[list[ConstraintStatus], set[str], set[str]]:
    """Match missing and extra check constraints by normalized definition."""
    statuses: list[ConstraintStatus] = []
    renamed_missing: set[str] = set()
    renamed_extra: set[str] = set()

    # Bucket both sides by their normalized CHECK expression.
    expected_buckets: dict[str, list[CheckConstraint]] = {}
    for candidate in missing:
        key = normalize_check_definition(
            _get_expected_check_definition(model, candidate)
        )
        expected_buckets.setdefault(key, []).append(candidate)

    actual_buckets: dict[str, list[str]] = {}
    for extra in extra_names:
        definition = actual_dict[extra].definition
        if definition:
            key = normalize_check_definition(definition)
            actual_buckets.setdefault(key, []).append(extra)

    # Only an unambiguous 1:1 pairing counts as a rename.
    for key, candidates in expected_buckets.items():
        matches = actual_buckets.get(key, [])
        if len(candidates) != 1 or len(matches) != 1:
            continue
        renamed = candidates[0]
        previous_name = matches[0]
        statuses.append(
            ConstraintStatus(
                name=renamed.name,
                constraint_type=ConType.CHECK,
                fields=[],
                issue=f"rename from {previous_name}",
                drift=ConstraintDrift(
                    kind=DriftKind.RENAMED,
                    table=table,
                    old_name=previous_name,
                    new_name=renamed.name,
                ),
            )
        )
        renamed_missing.add(renamed.name)
        renamed_extra.add(previous_name)

    return statuses, renamed_missing, renamed_extra
1097
1098
def _compare_index_only_unique(
    model: type[Model],
    constraint: UniqueConstraint,
    actual_state: ConstraintState,
    table: str,
) -> tuple[str | None, ConstraintDrift | None]:
    """Compare an index-only unique constraint against the DB.

    Index-only variants (condition, expressions, opclasses) are stored as
    unique indexes in PostgreSQL rather than pg_constraint rows, so their
    ConstraintState carries a pg_get_indexdef definition from the pg_index
    query path.

    Returns ``(issue, drift)``; both are None when the definitions match or
    when no DB definition is available to compare against.
    """
    definition = actual_state.definition
    if not definition:
        return None, None

    issue = _compare_parsed_index(
        model=model,
        expressions=constraint.expressions,
        fields=list(constraint.fields),
        opclasses=list(constraint.opclasses or ()),
        condition=constraint.condition,
        actual_def=definition,
    )
    if not issue:
        return None, None

    drift = ConstraintDrift(
        kind=DriftKind.CHANGED, table=table, constraint=constraint, model=model
    )
    return issue, drift
1130
1131
def _compare_parsed_index(
    *,
    model: type[Model],
    expressions: tuple[Expression | ReplaceableExpression, ...],
    fields: list[str],
    opclasses: list[str],
    condition: Q | None,
    actual_def: str,
) -> str | None:
    """Structured comparison of a model index/constraint against pg_get_indexdef.

    The DB definition is parsed into components (expression text, columns,
    opclasses, WHERE clause) which are compared one at a time, avoiding
    fragile full-SQL normalization between the ORM and PostgreSQL.

    Returns a human-readable issue string on mismatch, or None on a match.
    """
    parsed = _parse_index_definition(actual_def)

    # Expression indexes are compared as normalized SQL text; plain field
    # indexes are compared column-by-column.
    if expressions:
        expected_expr = normalize_expression(
            compile_index_expressions_sql(model, expressions)
        )
        if normalize_expression(parsed.expression_text) != expected_expr:
            return f"definition differs: DB has {actual_def!r}"
    else:
        meta = model._model_meta
        expected_columns = [meta.get_forward_field(f).column for f in fields]
        if parsed.columns != expected_columns:
            return (
                f"columns differ: DB has {parsed.columns}, "
                f"model expects {expected_columns}"
            )

    if parsed.opclasses != opclasses:
        return (
            f"opclasses differ: DB has {parsed.opclasses}, "
            f"model expects {opclasses}"
        )

    # WHERE-clause presence must agree before comparing its content.
    expects_where = condition is not None
    if expects_where != parsed.has_where:
        db_side = "has WHERE" if parsed.has_where else "no WHERE"
        model_side = "has" if expects_where else "no"
        return f"condition differs: DB {db_side}, model {model_side} condition"
    if expects_where and parsed.where_clause:
        assert condition is not None
        expected_where = compile_expression_sql(model, condition)
        if normalize_check_definition(
            parsed.where_clause
        ) != normalize_check_definition(expected_where):
            return f"condition differs: DB has WHERE ({parsed.where_clause})"

    return None
1182
1183
1184@dataclass
1185class _IndexParts:
1186 """Structured components parsed from pg_get_indexdef output."""
1187
1188 columns: list[str]
1189 opclasses: list[str]
1190 has_where: bool
1191 where_clause: str | None
1192 expression_text: str # raw text between the column-list parens
1193
1194
1195def _parse_index_definition(definition: str) -> _IndexParts:
1196 """Parse pg_get_indexdef output into structured components.
1197
1198 Extracts columns, opclasses, and WHERE clause from definitions like:
1199 CREATE UNIQUE INDEX name ON schema.table USING btree (col1, col2 opclass) WHERE (condition)
1200 """
1201 s = definition.lower().replace('"', "")
1202
1203 # Extract WHERE clause (everything after WHERE keyword)
1204 where_clause = None
1205 has_where = False
1206 where_match = re.search(r"\bwhere\s*\(", s)
1207 if where_match:
1208 has_where = True
1209 # Extract the balanced WHERE expression
1210 start = where_match.end() - 1 # include the opening paren
1211 depth = 0
1212 for i in range(start, len(s)):
1213 if s[i] == "(":
1214 depth += 1
1215 elif s[i] == ")":
1216 depth -= 1
1217 if depth == 0:
1218 where_clause = s[start + 1 : i].strip()
1219 s = s[: where_match.start()].strip()
1220 break
1221
1222 # Find the column list: content between parens after USING method
1223 columns: list[str] = []
1224 opclasses: list[str] = []
1225 expression_text = ""
1226 using_match = re.search(r"\busing\s+\w+\s*\(", s)
1227 if using_match:
1228 start = using_match.end()
1229 depth = 1
1230 for i in range(start, len(s)):
1231 if s[i] == "(":
1232 depth += 1
1233 elif s[i] == ")":
1234 depth -= 1
1235 if depth == 0:
1236 expression_text = s[start:i].strip()
1237 for part in expression_text.split(","):
1238 part = part.strip()
1239 # "col opclass" or just "col"
1240 tokens = part.split()
1241 if tokens:
1242 columns.append(tokens[0])
1243 opclasses.append(tokens[1] if len(tokens) > 1 else "")
1244 break
1245
1246 # Strip empty opclasses if none are set
1247 if all(oc == "" for oc in opclasses):
1248 opclasses = []
1249
1250 return _IndexParts(
1251 columns=columns,
1252 opclasses=opclasses,
1253 has_where=has_where,
1254 where_clause=where_clause,
1255 expression_text=expression_text,
1256 )
1257
1258
def _get_expected_check_definition(
    model: type[Model], constraint: CheckConstraint
) -> str:
    """Generate the CHECK expression that the model would produce."""
    # Mirrors pg_get_constraintdef's "CHECK (...)" wrapper shape.
    return f"CHECK ({compile_expression_sql(model, constraint.check)})"
1265
1266
def _get_expected_unique_definition(
    model: type[Model], constraint: UniqueConstraint
) -> str:
    """Generate the UNIQUE definition in pg_get_constraintdef format.

    PostgreSQL only stores field-based unique constraints (with optional
    INCLUDE and DEFERRABLE) in pg_constraint. Expression-based, conditional,
    and opclass constraints cannot be attached as constraints — they remain
    as indexes only.
    """
    from ..ddl import build_include_sql, deferrable_sql

    meta = model._model_meta
    quoted_columns = [
        quote_name(meta.get_forward_field(field_name).column)
        for field_name in constraint.fields
    ]
    parts = [
        f"UNIQUE ({', '.join(quoted_columns)})",
        build_include_sql(model, constraint.include),
        deferrable_sql(constraint.deferrable),
    ]
    return "".join(parts)