v0.150.0
  1from __future__ import annotations
  2
  3from collections.abc import Sequence
  4from dataclasses import dataclass, field
  5from typing import TYPE_CHECKING
  6
  7from ..db import get_connection
  8from ..registry import models_registry
  9from .analysis import (
 10    ColumnDefaultDrift,
 11    ConstraintDrift,
 12    Drift,
 13    DriftKind,
 14    ForeignKeyDrift,
 15    IndexDrift,
 16    NullabilityDrift,
 17    StorageParameterDrift,
 18    analyze_model,
 19)
 20from .fixes import (
 21    AddConstraintFix,
 22    AddForeignKeyFix,
 23    CreateIndexFix,
 24    DropColumnDefaultFix,
 25    DropConstraintFix,
 26    DropIndexFix,
 27    DropNotNullFix,
 28    Fix,
 29    RebuildIndexFix,
 30    RenameConstraintFix,
 31    RenameIndexFix,
 32    ReplaceForeignKeyFix,
 33    ResetStorageParameterFix,
 34    SetColumnDefaultFix,
 35    SetNotNullFix,
 36    SetStorageParameterFix,
 37    ValidateConstraintFix,
 38)
 39
 40if TYPE_CHECKING:
 41    from ..base import Model
 42    from ..connection import DatabaseConnection
 43    from ..utils import CursorWrapper
 44
 45
 46# Plan items — drift + policy + executable action
 47
 48
 49@dataclass
 50class PlanItem:
 51    """A planned convergence action: drift description + policy + optional fix."""
 52
 53    drift: Drift
 54    fix: Fix | None = None
 55    blocks_sync: bool = True
 56    guidance: str | None = None
 57
 58    def describe(self) -> str:
 59        if self.fix:
 60            return self.fix.describe()
 61        return self.drift.describe()
 62
 63
def _plan_drift(drift: Drift) -> PlanItem:
    """Map a semantic drift to a plan item with policy. All policy lives here.

    For each drift shape this chooses the ``Fix`` that resolves it (or
    ``fix=None`` plus operator ``guidance`` when it must be handled by hand)
    and whether the item is sync-blocking (``blocks_sync``; items that do not
    pass it explicitly get the ``PlanItem`` default).

    Args:
        drift: A drift reported by schema analysis.

    Returns:
        The ``PlanItem`` describing how convergence should treat ``drift``.

    Raises:
        ValueError: if ``drift`` matches none of the cases below.
    """
    match drift:
        # Indexes: every index action is non-blocking (blocks_sync=False).
        case IndexDrift(kind=DriftKind.MISSING, table=t, index=idx, model=m):
            return PlanItem(drift, CreateIndexFix(t, idx, m), blocks_sync=False)
        # Invalid and changed indexes are both resolved by rebuilding.
        case IndexDrift(kind=DriftKind.INVALID, table=t, index=idx, model=m):
            return PlanItem(drift, RebuildIndexFix(t, idx, m), blocks_sync=False)
        case IndexDrift(kind=DriftKind.CHANGED, table=t, index=idx, model=m):
            return PlanItem(drift, RebuildIndexFix(t, idx, m), blocks_sync=False)
        case IndexDrift(kind=DriftKind.RENAMED, table=t, old_name=old, new_name=new):
            return PlanItem(drift, RenameIndexFix(t, old, new), blocks_sync=False)
        case IndexDrift(kind=DriftKind.UNDECLARED, table=t, name=n):
            return PlanItem(drift, DropIndexFix(t, n), blocks_sync=False)
        # Constraints: sync-blocking by default; only renames are non-blocking.
        case ConstraintDrift(kind=DriftKind.MISSING, table=t, constraint=c, model=m):
            return PlanItem(drift, AddConstraintFix(t, c, m))
        case ConstraintDrift(kind=DriftKind.UNVALIDATED, table=t, name=n):
            return PlanItem(drift, ValidateConstraintFix(t, n))
        # A changed constraint definition is not rewritten automatically; the
        # guidance describes a staged add-new-then-remove-old rollout instead.
        case ConstraintDrift(kind=DriftKind.CHANGED):
            return PlanItem(
                drift,
                fix=None,
                guidance=(
                    "Declare a new constraint under a new name, run sync to add it,"
                    " then remove the old one."
                ),
            )
        case ConstraintDrift(
            kind=DriftKind.RENAMED, table=t, old_name=old, new_name=new
        ):
            return PlanItem(drift, RenameConstraintFix(t, old, new), blocks_sync=False)
        case ConstraintDrift(kind=DriftKind.UNDECLARED, table=t, name=n):
            return PlanItem(drift, DropConstraintFix(t, n))
        # Foreign keys: none of these cases sets blocks_sync=False, so all
        # foreign-key items are sync-blocking.
        case ForeignKeyDrift(
            kind=DriftKind.MISSING,
            table=t,
            name=cn,
            column=col,
            target_table=tt,
            target_column=tc,
            on_delete_clause=od,
        ):
            return PlanItem(drift, AddForeignKeyFix(t, cn, col, tt, tc, od))
        case ForeignKeyDrift(
            kind=DriftKind.CHANGED,
            table=t,
            name=cn,
            column=col,
            target_table=tt,
            target_column=tc,
            on_delete_clause=od,
        ):
            # NOTE(review): these asserts presumably narrow Optional fields on
            # ForeignKeyDrift for the type checker; they are stripped under -O.
            assert cn is not None
            assert col is not None
            assert tt is not None
            assert tc is not None
            return PlanItem(drift, ReplaceForeignKeyFix(t, cn, col, tt, tc, od))
        case ForeignKeyDrift(kind=DriftKind.UNVALIDATED, table=t, name=n):
            return PlanItem(drift, ValidateConstraintFix(t, n))
        case ForeignKeyDrift(kind=DriftKind.UNDECLARED, table=t, name=n):
            return PlanItem(drift, DropConstraintFix(t, n))
        # Nullability: NOT NULL is only added when no existing row is NULL;
        # otherwise the item has no fix and carries backfill guidance.
        case NullabilityDrift(
            table=t, column=col, model_allows_null=False, has_null_rows=False
        ):
            return PlanItem(drift, SetNotNullFix(t, col))
        case NullabilityDrift(model_allows_null=False, has_null_rows=True):
            return PlanItem(
                drift,
                fix=None,
                guidance="Backfill existing NULL rows, then rerun sync.",
            )
        case NullabilityDrift(table=t, column=col, model_allows_null=True):
            return PlanItem(drift, DropNotNullFix(t, col))
        # Column defaults: set when missing or changed, drop when undeclared.
        case ColumnDefaultDrift(
            kind=DriftKind.MISSING | DriftKind.CHANGED,
            table=t,
            column=col,
            model_default_sql=default_sql,
        ):
            assert default_sql is not None  # MISSING/CHANGED always carry model SQL
            return PlanItem(drift, SetColumnDefaultFix(t, col, default_sql))
        case ColumnDefaultDrift(kind=DriftKind.UNDECLARED, table=t, column=col):
            return PlanItem(drift, DropColumnDefaultFix(t, col))
        # Storage parameters: set when missing or changed, reset when undeclared.
        case StorageParameterDrift(
            kind=DriftKind.MISSING | DriftKind.CHANGED,
            table=t,
            key=k,
            declared_value=v,
        ):
            assert v is not None  # presumably always set for MISSING/CHANGED — confirm
            return PlanItem(drift, SetStorageParameterFix(t, k, v))
        case StorageParameterDrift(kind=DriftKind.UNDECLARED, table=t, key=k):
            return PlanItem(drift, ResetStorageParameterFix(t, k))
        # Anything not matched above has no policy; fail loudly instead of
        # silently ignoring it.
        case _:
            raise ValueError(f"Unhandled drift: {drift}")
158
159
def can_auto_fix(drift: Drift) -> bool:
    """Report whether convergence has an automatic fix for *drift*.

    The schema command uses this to split drifts into fixable and non-fixable
    for display. It defers to ``_plan_drift`` so the fix policy is defined in
    exactly one place. Like ``_plan_drift``, it raises ``ValueError`` for a
    drift it does not recognize.
    """
    planned = _plan_drift(drift)
    return planned.fix is not None
167
168
169# Execution results
170
171
@dataclass
class FixResult:
    """Result of attempting one plan item: the SQL it ran, or the error it raised."""

    item: PlanItem  # the plan item that was attempted
    sql: str | None = None  # value returned by the fix's apply() on success
    error: Exception | None = None  # exception raised while applying, if any

    @property
    def ok(self) -> bool:
        """True when applying the item raised no error."""
        failed = self.error is not None
        return not failed
183
184
185# Convergence plan
186
187
@dataclass
class ConvergencePlan:
    """All planned convergence actions, ready for filtering and execution."""

    items: list[PlanItem]  # every planned item, including ones without a fix

    def executable(self) -> list[PlanItem]:
        """Items eligible for execution (those with a fix), sorted by pass_order.

        The sort is stable, so a plan whose items are already ordered by
        ``pass_order`` comes back in its existing order. A plan assembled
        some other way is still returned in execution order, as promised.
        """
        runnable = [item for item in self.items if item.fix is not None]
        # Same key the planners use; every item here has a fix, so the
        # ``inf`` branch is only there to keep the type checker happy.
        runnable.sort(key=lambda item: item.fix.pass_order if item.fix else float("inf"))
        return runnable

    def has_work(self) -> bool:
        """Would execution produce any items?"""
        # Short-circuits; no need to build and sort the executable list.
        return any(item.fix is not None for item in self.items)

    @property
    def blocked(self) -> list[PlanItem]:
        """Items that cannot be auto-resolved (require staged rollout)."""
        return [item for item in self.items if item.fix is None]
206
207
208# Convergence result
209
210
@dataclass
class ConvergenceResult:
    """Outcome of executing convergence plan items."""

    results: list[FixResult] = field(default_factory=list)  # one per attempted item

    @property
    def applied(self) -> int:
        """Number of items that applied without error."""
        return sum(1 for r in self.results if r.ok)

    @property
    def failed(self) -> int:
        """Number of items whose fix raised an error."""
        return sum(1 for r in self.results if not r.ok)

    @property
    def ok(self) -> bool:
        """True if every attempted item succeeded (vacuously true when empty)."""
        return all(r.ok for r in self.results)

    @property
    def ok_for_sync(self) -> bool:
        """True if no sync-blocking items failed."""
        return all(r.ok for r in self.results if r.item.blocks_sync)

    @property
    def blocking_failures(self) -> list[FixResult]:
        """Failed results whose plan item blocks sync."""
        return [r for r in self.results if not r.ok and r.item.blocks_sync]

    @property
    def non_blocking_failures(self) -> list[FixResult]:
        """Failed results whose plan item does not block sync."""
        return [r for r in self.results if not r.ok and not r.item.blocks_sync]

    @property
    def summary(self) -> str:
        """One-line tally such as ``"2 applied, 1 failed."``.

        An empty result reports ``"0 applied."`` instead of a bare ``"."``.
        """
        applied, failed = self.applied, self.failed  # compute each count once
        if not applied and not failed:
            return "0 applied."
        parts = []
        if applied:
            parts.append(f"{applied} applied")
        if failed:
            parts.append(f"{failed} failed")
        return ", ".join(parts) + "."
250
251
252# Plan construction
253
254
def plan_convergence() -> ConvergencePlan:
    """Build a convergence plan covering every registered model.

    Analyzes each model from ``models_registry`` against the database, sharing
    one cursor, and maps every detected drift to a ``PlanItem``. Items are
    ordered by their fix's ``pass_order``; items without a fix sort last.
    """
    connection = get_connection()
    with connection.cursor() as cursor:
        planned = [
            _plan_drift(drift)
            for model in models_registry.get_models()
            for drift in analyze_model(connection, cursor, model).drifts
        ]

    def execution_rank(item: PlanItem) -> float:
        # Fixless (blocked) items have nothing to run, so push them to the end.
        if not item.fix:
            return float("inf")
        return item.fix.pass_order

    planned.sort(key=execution_rank)
    return ConvergencePlan(items=planned)
267
268
def plan_model_convergence(
    conn: DatabaseConnection, cursor: CursorWrapper, model: type[Model]
) -> ConvergencePlan:
    """Build a convergence plan for one model, using the caller's connection and cursor.

    Items are ordered by their fix's ``pass_order``; items without a fix sort
    last.
    """
    analysis = analyze_model(conn, cursor, model)
    planned = sorted(
        (_plan_drift(drift) for drift in analysis.drifts),
        key=lambda item: float("inf") if not item.fix else item.fix.pass_order,
    )
    return ConvergencePlan(items=planned)
276
277
def execute_plan(items: Sequence[PlanItem]) -> ConvergenceResult:
    """Apply plan items one at a time, collecting a result for each.

    Every item must carry a fix (for example, the output of
    ``ConvergencePlan.executable()``). This is checked for all items before
    anything is applied, so bad input fails fast. Otherwise the run would
    abort midway after earlier fixes had already run and their results
    would be lost.

    Each fix is applied independently: an exception from one fix is recorded
    in its ``FixResult`` and does not stop the remaining items. Committing is
    left to each fix's ``apply()`` — presumably it commits its own work;
    confirm in ``.fixes``.

    Args:
        items: Plan items to apply, in the order they should run.

    Returns:
        A ``ConvergenceResult`` with one ``FixResult`` per item, in input order.

    Raises:
        ValueError: if any item has no fix. Nothing is applied in that case.
    """
    # Validate up front with a real check (not ``assert``, which -O strips).
    pending: list[tuple[PlanItem, Fix]] = []
    unfixable: list[PlanItem] = []
    for item in items:
        if item.fix is None:
            unfixable.append(item)
        else:
            pending.append((item, item.fix))
    if unfixable:
        raise ValueError(
            f"execute_plan() received {len(unfixable)} item(s) without a fix: "
            + "; ".join(item.describe() for item in unfixable)
        )

    result = ConvergenceResult()
    for item, fix in pending:
        try:
            sql = fix.apply()
        except Exception as e:  # best-effort: record the failure and keep going
            result.results.append(FixResult(item=item, error=e))
        else:
            result.results.append(FixResult(item=item, sql=sql))
    return result