1from __future__ import annotations
2
3from collections.abc import Sequence
4from dataclasses import dataclass, field
5from typing import TYPE_CHECKING
6
7from ..db import get_connection
8from ..registry import models_registry
9from .analysis import (
10 ColumnDefaultDrift,
11 ConstraintDrift,
12 Drift,
13 DriftKind,
14 ForeignKeyDrift,
15 IndexDrift,
16 NullabilityDrift,
17 StorageParameterDrift,
18 analyze_model,
19)
20from .fixes import (
21 AddConstraintFix,
22 AddForeignKeyFix,
23 CreateIndexFix,
24 DropColumnDefaultFix,
25 DropConstraintFix,
26 DropIndexFix,
27 DropNotNullFix,
28 Fix,
29 RebuildIndexFix,
30 RenameConstraintFix,
31 RenameIndexFix,
32 ReplaceForeignKeyFix,
33 ResetStorageParameterFix,
34 SetColumnDefaultFix,
35 SetNotNullFix,
36 SetStorageParameterFix,
37 ValidateConstraintFix,
38)
39
40if TYPE_CHECKING:
41 from ..base import Model
42 from ..connection import DatabaseConnection
43 from ..utils import CursorWrapper
44
45
46# Plan items — drift + policy + executable action
47
48
@dataclass
class PlanItem:
    """One entry of a convergence plan: a drift, its policy, and an optional fix."""

    # The drift this entry was planned for.
    drift: Drift
    # Action that resolves the drift; None when no fix is attached.
    fix: Fix | None = None
    # Policy flag: whether this item is treated as sync-blocking.
    blocks_sync: bool = True
    # Optional hint for resolving the drift by hand.
    guidance: str | None = None

    def describe(self) -> str:
        """Return the fix's description when a fix is set, else the drift's."""
        # Truthiness (not an explicit None check) decides which one speaks.
        subject = self.fix or self.drift
        return subject.describe()
62
63
64def _plan_drift(drift: Drift) -> PlanItem:
65 """Map a semantic drift to a plan item with policy. All policy lives here."""
66 match drift:
67 case IndexDrift(kind=DriftKind.MISSING, table=t, index=idx, model=m):
68 return PlanItem(drift, CreateIndexFix(t, idx, m), blocks_sync=False)
69 case IndexDrift(kind=DriftKind.INVALID, table=t, index=idx, model=m):
70 return PlanItem(drift, RebuildIndexFix(t, idx, m), blocks_sync=False)
71 case IndexDrift(kind=DriftKind.CHANGED, table=t, index=idx, model=m):
72 return PlanItem(drift, RebuildIndexFix(t, idx, m), blocks_sync=False)
73 case IndexDrift(kind=DriftKind.RENAMED, table=t, old_name=old, new_name=new):
74 return PlanItem(drift, RenameIndexFix(t, old, new), blocks_sync=False)
75 case IndexDrift(kind=DriftKind.UNDECLARED, table=t, name=n):
76 return PlanItem(drift, DropIndexFix(t, n), blocks_sync=False)
77 case ConstraintDrift(kind=DriftKind.MISSING, table=t, constraint=c, model=m):
78 return PlanItem(drift, AddConstraintFix(t, c, m))
79 case ConstraintDrift(kind=DriftKind.UNVALIDATED, table=t, name=n):
80 return PlanItem(drift, ValidateConstraintFix(t, n))
81 case ConstraintDrift(kind=DriftKind.CHANGED):
82 return PlanItem(
83 drift,
84 fix=None,
85 guidance=(
86 "Declare a new constraint under a new name, run sync to add it,"
87 " then remove the old one."
88 ),
89 )
90 case ConstraintDrift(
91 kind=DriftKind.RENAMED, table=t, old_name=old, new_name=new
92 ):
93 return PlanItem(drift, RenameConstraintFix(t, old, new), blocks_sync=False)
94 case ConstraintDrift(kind=DriftKind.UNDECLARED, table=t, name=n):
95 return PlanItem(drift, DropConstraintFix(t, n))
96 case ForeignKeyDrift(
97 kind=DriftKind.MISSING,
98 table=t,
99 name=cn,
100 column=col,
101 target_table=tt,
102 target_column=tc,
103 on_delete_clause=od,
104 ):
105 return PlanItem(drift, AddForeignKeyFix(t, cn, col, tt, tc, od))
106 case ForeignKeyDrift(
107 kind=DriftKind.CHANGED,
108 table=t,
109 name=cn,
110 column=col,
111 target_table=tt,
112 target_column=tc,
113 on_delete_clause=od,
114 ):
115 assert cn is not None
116 assert col is not None
117 assert tt is not None
118 assert tc is not None
119 return PlanItem(drift, ReplaceForeignKeyFix(t, cn, col, tt, tc, od))
120 case ForeignKeyDrift(kind=DriftKind.UNVALIDATED, table=t, name=n):
121 return PlanItem(drift, ValidateConstraintFix(t, n))
122 case ForeignKeyDrift(kind=DriftKind.UNDECLARED, table=t, name=n):
123 return PlanItem(drift, DropConstraintFix(t, n))
124 case NullabilityDrift(
125 table=t, column=col, model_allows_null=False, has_null_rows=False
126 ):
127 return PlanItem(drift, SetNotNullFix(t, col))
128 case NullabilityDrift(model_allows_null=False, has_null_rows=True):
129 return PlanItem(
130 drift,
131 fix=None,
132 guidance="Backfill existing NULL rows, then rerun sync.",
133 )
134 case NullabilityDrift(table=t, column=col, model_allows_null=True):
135 return PlanItem(drift, DropNotNullFix(t, col))
136 case ColumnDefaultDrift(
137 kind=DriftKind.MISSING | DriftKind.CHANGED,
138 table=t,
139 column=col,
140 model_default_sql=default_sql,
141 ):
142 assert default_sql is not None # MISSING/CHANGED always carry model SQL
143 return PlanItem(drift, SetColumnDefaultFix(t, col, default_sql))
144 case ColumnDefaultDrift(kind=DriftKind.UNDECLARED, table=t, column=col):
145 return PlanItem(drift, DropColumnDefaultFix(t, col))
146 case StorageParameterDrift(
147 kind=DriftKind.MISSING | DriftKind.CHANGED,
148 table=t,
149 key=k,
150 declared_value=v,
151 ):
152 assert v is not None
153 return PlanItem(drift, SetStorageParameterFix(t, k, v))
154 case StorageParameterDrift(kind=DriftKind.UNDECLARED, table=t, key=k):
155 return PlanItem(drift, ResetStorageParameterFix(t, k))
156 case _:
157 raise ValueError(f"Unhandled drift: {drift}")
158
159
def can_auto_fix(drift: Drift) -> bool:
    """Report whether the plan item for this drift carries an automatic fix.

    Intended for display code that separates fixable from non-fixable drifts
    (the schema command, per the original note). The answer is read from
    _plan_drift, so fix policy is defined in one place only; a drift that
    _plan_drift does not handle propagates its ValueError.
    """
    planned = _plan_drift(drift)
    return planned.fix is not None
167
168
169# Execution results
170
171
@dataclass
class FixResult:
    """Record of what happened when one plan item was applied."""

    # The plan item this result belongs to.
    item: PlanItem
    # Value returned by the fix's apply() call, when one was recorded.
    sql: str | None = None
    # Exception captured while applying, or None when nothing was caught.
    error: Exception | None = None

    @property
    def ok(self) -> bool:
        """True when no error was recorded for this item."""
        failed = self.error is not None
        return not failed
183
184
185# Convergence plan
186
187
@dataclass
class ConvergencePlan:
    """All planned convergence actions, ready for filtering and execution."""

    # Every plan item, with or without a fix attached.
    items: list[PlanItem]

    def executable(self) -> list[PlanItem]:
        """Return the items that carry a fix, ordered by ascending pass_order.

        The ordering is enforced here instead of relying on ``items`` having
        been sorted by whoever built the plan. The sort is stable, so items
        sharing a pass_order keep their relative order from ``items`` and a
        plan that is already sorted comes back unchanged. ``items`` itself is
        not mutated.
        """
        # Pair each item with its rank inside the comprehension, where the
        # ``is not None`` filter has already ruled out fix-less items.
        ranked = [
            (item.fix.pass_order, item)
            for item in self.items
            if item.fix is not None
        ]
        ranked.sort(key=lambda pair: pair[0])
        return [item for _, item in ranked]

    def has_work(self) -> bool:
        """Return True when at least one item carries a fix."""
        # Equivalent to bool(self.executable()) without building a list.
        return any(item.fix is not None for item in self.items)

    @property
    def blocked(self) -> list[PlanItem]:
        """Items with no fix attached, i.e. those that cannot be auto-resolved."""
        return [item for item in self.items if item.fix is None]
206
207
208# Convergence result
209
210
@dataclass
class ConvergenceResult:
    """Outcome of executing convergence plan items."""

    # One FixResult per attempted plan item, in the order they were appended.
    results: list[FixResult] = field(default_factory=list)

    @property
    def applied(self) -> int:
        """Number of results that succeeded."""
        return sum(1 for r in self.results if r.ok)

    @property
    def failed(self) -> int:
        """Number of results that recorded an error."""
        return sum(1 for r in self.results if not r.ok)

    @property
    def ok(self) -> bool:
        """True when every result succeeded (also True with no results)."""
        return all(r.ok for r in self.results)

    @property
    def ok_for_sync(self) -> bool:
        """True if no sync-blocking items failed."""
        return all(r.ok for r in self.results if r.item.blocks_sync)

    @property
    def blocking_failures(self) -> list[FixResult]:
        """Failed results whose item is flagged blocks_sync."""
        return [r for r in self.results if not r.ok and r.item.blocks_sync]

    @property
    def non_blocking_failures(self) -> list[FixResult]:
        """Failed results whose item is not flagged blocks_sync."""
        return [r for r in self.results if not r.ok and not r.item.blocks_sync]

    @property
    def summary(self) -> str:
        """One-line tally such as ``"2 applied, 1 failed."``.

        Zero counts are left out. When there is nothing to report at all the
        summary is ``"0 applied."`` rather than a lone period.
        """
        # Count once; each property walks the whole result list.
        applied, failed = self.applied, self.failed
        parts = []
        if applied:
            parts.append(f"{applied} applied")
        if failed:
            parts.append(f"{failed} failed")
        if not parts:
            return "0 applied."
        return ", ".join(parts) + "."
250
251
252# Plan construction
253
254
def plan_convergence() -> ConvergencePlan:
    """Build a convergence plan covering every model in the registry.

    Each registered model is analyzed against the database through a single
    cursor, and every drift found is turned into a plan item. The finished
    list is ordered by ascending fix pass_order, with fix-less items last.
    """
    connection = get_connection()

    with connection.cursor() as cur:
        planned = [
            _plan_drift(found)
            for registered in models_registry.get_models()
            for found in analyze_model(connection, cur, registered).drifts
        ]

    def execution_rank(entry: PlanItem) -> float:
        # Items without a fix sort after every item that has one.
        return entry.fix.pass_order if entry.fix else float("inf")

    planned.sort(key=execution_rank)
    return ConvergencePlan(items=planned)
267
268
def plan_model_convergence(
    conn: DatabaseConnection, cursor: CursorWrapper, model: type[Model]
) -> ConvergencePlan:
    """Build a convergence plan for one model only.

    The model is analyzed with the given connection and cursor; each drift
    becomes a plan item, ordered by ascending fix pass_order with fix-less
    items last.
    """
    analysis = analyze_model(conn, cursor, model)
    ordered = sorted(
        (_plan_drift(found) for found in analysis.drifts),
        # Items without a fix sort after every item that has one.
        key=lambda entry: entry.fix.pass_order if entry.fix else float("inf"),
    )
    return ConvergencePlan(items=ordered)
276
277
def execute_plan(items: Sequence[PlanItem]) -> ConvergenceResult:
    """Apply each plan item in turn and collect one result per item.

    An exception raised while applying an item is captured in that item's
    result rather than propagated, so later items still run. Per the original
    note, each item is also committed independently; that presumably happens
    inside the fix's apply() and is not visible here.

    Every item must carry a fix; this is asserted before the item is applied.
    """
    outcome = ConvergenceResult()
    for entry in items:
        assert entry.fix is not None
        try:
            statement = entry.fix.apply()
        except Exception as exc:
            record = FixResult(item=entry, error=exc)
        else:
            record = FixResult(item=entry, sql=statement)
        outcome.results.append(record)
    return outcome