1from __future__ import annotations
2
3from collections.abc import Sequence
4from dataclasses import dataclass, field
5from typing import TYPE_CHECKING
6
7from ..db import get_connection
8from ..registry import models_registry
9from .analysis import (
10 ConstraintDrift,
11 Drift,
12 DriftKind,
13 ForeignKeyDrift,
14 IndexDrift,
15 NullabilityDrift,
16 analyze_model,
17)
18from .fixes import (
19 AddConstraintFix,
20 AddForeignKeyFix,
21 CreateIndexFix,
22 DropConstraintFix,
23 DropIndexFix,
24 DropNotNullFix,
25 Fix,
26 RebuildIndexFix,
27 RenameConstraintFix,
28 RenameIndexFix,
29 SetNotNullFix,
30 ValidateConstraintFix,
31)
32
33if TYPE_CHECKING:
34 from ..base import Model
35 from ..connection import DatabaseConnection
36 from ..utils import CursorWrapper
37
38
# Plan items — drift + policy + executable action
40
41
@dataclass
class PlanItem:
    """One planned convergence step: the drift, its policy, and maybe a fix.

    Fields, in positional order:
        drift: the detected schema drift this item addresses.
        fix: executable action resolving the drift, or None when the drift
            cannot be resolved automatically.
        blocks_sync: whether a failure of this item counts against sync.
        drop_undeclared: whether this item drops an undeclared object and so
            only runs when the caller opts into dropping.
        guidance: optional human instructions for drift with no fix.
    """

    drift: Drift
    fix: Fix | None = None
    blocks_sync: bool = True
    drop_undeclared: bool = False
    guidance: str | None = None

    def describe(self) -> str:
        """Describe the fix when there is one, otherwise the drift itself."""
        return (self.fix or self.drift).describe()
56
57
def _plan_drift(drift: Drift) -> PlanItem:
    """Map a semantic drift to a plan item with policy. All policy lives here.

    Cases are tried top to bottom; the first matching pattern wins.

    Args:
        drift: One drift from ``analyze_model(...).drifts``.

    Returns:
        A PlanItem whose ``fix`` is the action to apply, or ``None`` together
        with a ``guidance`` message when the drift needs manual intervention.
        ``blocks_sync`` defaults to True unless a case overrides it;
        ``drop_undeclared`` is set only for drops of undeclared objects.

    Raises:
        ValueError: If no case below handles the drift.
    """
    match drift:
        # Indexes: every case sets blocks_sync=False, so a failed index fix is
        # excluded from ConvergenceResult.ok_for_sync.
        case IndexDrift(kind=DriftKind.MISSING, table=t, index=idx, model=m):
            return PlanItem(drift, CreateIndexFix(t, idx, m), blocks_sync=False)
        case IndexDrift(kind=DriftKind.INVALID, table=t, index=idx, model=m):
            return PlanItem(drift, RebuildIndexFix(t, idx, m), blocks_sync=False)
        # A changed index definition is resolved the same way as an invalid one.
        case IndexDrift(kind=DriftKind.CHANGED, table=t, index=idx, model=m):
            return PlanItem(drift, RebuildIndexFix(t, idx, m), blocks_sync=False)
        case IndexDrift(kind=DriftKind.RENAMED, table=t, old_name=old, new_name=new):
            return PlanItem(drift, RenameIndexFix(t, old, new), blocks_sync=False)
        # drop_undeclared=True: ConvergencePlan.executable() filters this out
        # unless the caller passes drop_undeclared=True.
        case IndexDrift(kind=DriftKind.UNDECLARED, table=t, name=n):
            return PlanItem(
                drift, DropIndexFix(t, n), blocks_sync=False, drop_undeclared=True
            )
        # Constraints: these block sync by default; only renames are non-blocking.
        case ConstraintDrift(kind=DriftKind.MISSING, table=t, constraint=c, model=m):
            return PlanItem(drift, AddConstraintFix(t, c, m))
        case ConstraintDrift(kind=DriftKind.UNVALIDATED, table=t, name=n):
            return PlanItem(drift, ValidateConstraintFix(t, n))
        # No automatic fix for a changed constraint; the guidance describes a
        # staged rollout (add the new one, then drop the old one).
        case ConstraintDrift(kind=DriftKind.CHANGED):
            return PlanItem(
                drift,
                fix=None,
                guidance=(
                    "Declare a new constraint under a new name, run sync to add it,"
                    " then remove the old one with --drop-undeclared."
                ),
            )
        case ConstraintDrift(
            kind=DriftKind.RENAMED, table=t, old_name=old, new_name=new
        ):
            return PlanItem(drift, RenameConstraintFix(t, old, new), blocks_sync=False)
        # Undeclared constraint drops keep blocks_sync=True (the default), so
        # they show up in ConvergencePlan.blocking_cleanup.
        case ConstraintDrift(kind=DriftKind.UNDECLARED, table=t, name=n):
            return PlanItem(drift, DropConstraintFix(t, n), drop_undeclared=True)
        # Foreign keys. Validation reuses ValidateConstraintFix.
        # NOTE(review): there are no CHANGED/RENAMED cases for ForeignKeyDrift;
        # such drift would reach the ValueError below — confirm analysis never
        # emits them.
        case ForeignKeyDrift(
            kind=DriftKind.MISSING,
            table=t,
            name=cn,
            column=col,
            target_table=tt,
            target_column=tc,
        ):
            return PlanItem(drift, AddForeignKeyFix(t, cn, col, tt, tc))
        case ForeignKeyDrift(kind=DriftKind.UNVALIDATED, table=t, name=n):
            return PlanItem(drift, ValidateConstraintFix(t, n))
        case ForeignKeyDrift(kind=DriftKind.UNDECLARED, table=t, name=n):
            return PlanItem(drift, DropConstraintFix(t, n), drop_undeclared=True)
        # Nullability. True/False literal patterns match by identity, so
        # model_allows_null and has_null_rows must be real bools.
        # NOTE(review): with model_allows_null=False, a has_null_rows value that
        # is neither True nor False (e.g. None) matches no case and raises
        # ValueError — confirm analysis always sets a bool.
        case NullabilityDrift(
            table=t, column=col, model_allows_null=False, has_null_rows=False
        ):
            return PlanItem(drift, SetNotNullFix(t, col))
        # NOT NULL cannot be set while NULL rows exist; the data must be fixed first.
        case NullabilityDrift(model_allows_null=False, has_null_rows=True):
            return PlanItem(
                drift,
                fix=None,
                guidance="Backfill existing NULL rows, then rerun sync.",
            )
        case NullabilityDrift(table=t, column=col, model_allows_null=True):
            return PlanItem(drift, DropNotNullFix(t, col))
        case _:
            raise ValueError(f"Unhandled drift: {drift}")
119
120
def can_auto_fix(drift: Drift) -> bool:
    """Report whether convergence has an automatic fix for *drift*.

    The schema command uses this to label drift as fixable or non-fixable.
    No policy is duplicated here: the answer comes from ``_plan_drift``, the
    single source of truth.

    Raises:
        ValueError: Propagated from ``_plan_drift`` for unhandled drift.
    """
    planned = _plan_drift(drift)
    has_fix = planned.fix is not None
    return has_fix
128
129
130# Execution results
131
132
@dataclass
class FixResult:
    """What happened when one plan item was applied.

    Fields:
        item: the plan item that was attempted.
        sql: SQL returned by the fix's ``apply()`` on success.
        error: the exception raised by ``apply()``, if any.
    """

    item: PlanItem
    sql: str | None = None
    error: Exception | None = None

    @property
    def ok(self) -> bool:
        """True when applying the item raised no exception."""
        failed = self.error is not None
        return not failed
144
145
146# Convergence plan
147
148
@dataclass
class ConvergencePlan:
    """All planned convergence actions, ready for filtering and execution.

    Attributes:
        items: Every planned item, fixable or not. The plan builders in this
            module pre-sort by ``fix.pass_order`` (fix-less items last), but
            ``executable()`` does not rely on that.
    """

    items: list[PlanItem]

    def executable(self, *, drop_undeclared: bool = False) -> list[PlanItem]:
        """Items eligible for execution in this mode, sorted by pass_order.

        Args:
            drop_undeclared: When False (the default), items that drop
                undeclared database objects are excluded.

        Returns:
            Items that carry a fix, ordered by ``fix.pass_order``. The sort is
            stable, so items with equal pass_order keep their plan order.
        """
        eligible = [
            item
            for item in self.items
            if item.fix is not None and (drop_undeclared or not item.drop_undeclared)
        ]
        # Sort here instead of trusting the constructor's ordering: a plan
        # built by hand or by concatenating per-model plans may be unordered.
        eligible.sort(
            key=lambda item: item.fix.pass_order if item.fix else float("inf")
        )
        return eligible

    def has_work(self, *, drop_undeclared: bool = False) -> bool:
        """Would this mode produce any items to execute?"""
        return bool(self.executable(drop_undeclared=drop_undeclared))

    @property
    def blocked(self) -> list[PlanItem]:
        """Items that cannot be auto-resolved (require staged rollout)."""
        return [item for item in self.items if item.fix is None]

    @property
    def blocking_cleanup(self) -> list[PlanItem]:
        """Undeclared-object drops that block sync success."""
        return [
            item for item in self.items if item.drop_undeclared and item.blocks_sync
        ]

    @property
    def optional_cleanup(self) -> list[PlanItem]:
        """Undeclared-object drops safe to defer."""
        return [
            item for item in self.items if item.drop_undeclared and not item.blocks_sync
        ]
185
186
187# Convergence result
188
189
@dataclass
class ConvergenceResult:
    """Outcome of executing convergence plan items.

    Attributes:
        results: One FixResult per attempted item, in execution order.
    """

    results: list[FixResult] = field(default_factory=list)

    @property
    def applied(self) -> int:
        """Number of items whose fix applied without raising."""
        return sum(1 for r in self.results if r.ok)

    @property
    def failed(self) -> int:
        """Number of items whose fix raised."""
        return sum(1 for r in self.results if not r.ok)

    @property
    def ok(self) -> bool:
        """True if every attempted item succeeded (vacuously True when empty)."""
        return all(r.ok for r in self.results)

    @property
    def ok_for_sync(self) -> bool:
        """True if no sync-blocking items failed."""
        return all(r.ok for r in self.results if r.item.blocks_sync)

    @property
    def blocking_failures(self) -> list[FixResult]:
        """Failed results whose item has ``blocks_sync`` set."""
        return [r for r in self.results if not r.ok and r.item.blocks_sync]

    @property
    def non_blocking_failures(self) -> list[FixResult]:
        """Failed results whose item does not block sync."""
        return [r for r in self.results if not r.ok and not r.item.blocks_sync]

    @property
    def summary(self) -> str:
        """Human-readable tally such as ``"3 applied, 1 failed."``.

        An empty result reports ``"0 applied."`` instead of a bare ``"."``.
        """
        applied, failed = self.applied, self.failed
        parts = []
        # Always report the applied count unless failures are all there is,
        # so the summary is never empty.
        if applied or not failed:
            parts.append(f"{applied} applied")
        if failed:
            parts.append(f"{failed} failed")
        return ", ".join(parts) + "."
229
230
231# Plan construction
232
233
def plan_convergence() -> ConvergencePlan:
    """Build a convergence plan covering every registered model.

    Opens a cursor on the connection returned by ``get_connection()``, runs
    ``analyze_model`` for each model in ``models_registry``, maps every drift
    to a PlanItem, and orders the items by their fix's ``pass_order`` with
    fix-less items last.
    """
    connection = get_connection()

    with connection.cursor() as cursor:
        planned = [
            _plan_drift(drift)
            for model in models_registry.get_models()
            for drift in analyze_model(connection, cursor, model).drifts
        ]

    def _order(item: PlanItem) -> float:
        # Items without a fix cannot run, so they go after everything else.
        return item.fix.pass_order if item.fix else float("inf")

    planned.sort(key=_order)
    return ConvergencePlan(items=planned)
246
247
def plan_model_convergence(
    conn: DatabaseConnection, cursor: CursorWrapper, model: type[Model]
) -> ConvergencePlan:
    """Build a convergence plan for one model using a caller-supplied cursor.

    Each drift reported by ``analyze_model`` becomes a PlanItem. Items are
    ordered by their fix's ``pass_order``, with fix-less items last.
    """
    analysis = analyze_model(conn, cursor, model)
    ordered = sorted(
        (_plan_drift(drift) for drift in analysis.drifts),
        key=lambda item: float("inf") if not item.fix else item.fix.pass_order,
    )
    return ConvergencePlan(items=ordered)
255
256
def execute_plan(items: Sequence[PlanItem]) -> ConvergenceResult:
    """Apply plan items independently, collecting results.

    Each item's fix is applied on its own. An exception from one fix is
    recorded in that item's FixResult instead of stopping the run, so partial
    failures don't block subsequent items. Nothing here commits.
    NOTE(review): committing is presumably done inside ``Fix.apply`` — confirm.

    Args:
        items: Items to apply, in order. Every item must carry a fix, e.g.
            the output of ``ConvergencePlan.executable(...)``.

    Returns:
        A ConvergenceResult with one FixResult per item, in input order.

    Raises:
        ValueError: If any item has no fix. This is checked before any fix is
            applied, so a malformed batch never applies some fixes and then
            aborts with their results lost. (A bare ``assert`` did exactly
            that, and it vanishes entirely under ``python -O``.)
    """
    missing = sum(1 for item in items if item.fix is None)
    if missing:
        raise ValueError(
            f"execute_plan() received {missing} item(s) without a fix; "
            "pass ConvergencePlan.executable() items only."
        )

    result = ConvergenceResult()
    for item in items:
        fix = item.fix
        assert fix is not None  # guaranteed by the check above; narrows the type
        try:
            sql = fix.apply()
        except Exception as e:
            # Deliberate best-effort: record the failure and keep going.
            result.results.append(FixResult(item=item, error=e))
        else:
            result.results.append(FixResult(item=item, sql=sql))
    return result