1from __future__ import annotations
  2
  3from collections.abc import Sequence
  4from dataclasses import dataclass, field
  5from typing import TYPE_CHECKING
  6
  7from ..db import get_connection
  8from ..registry import models_registry
  9from .analysis import (
 10    ConstraintDrift,
 11    Drift,
 12    DriftKind,
 13    ForeignKeyDrift,
 14    IndexDrift,
 15    NullabilityDrift,
 16    analyze_model,
 17)
 18from .fixes import (
 19    AddConstraintFix,
 20    AddForeignKeyFix,
 21    CreateIndexFix,
 22    DropConstraintFix,
 23    DropIndexFix,
 24    DropNotNullFix,
 25    Fix,
 26    RebuildIndexFix,
 27    RenameConstraintFix,
 28    RenameIndexFix,
 29    SetNotNullFix,
 30    ValidateConstraintFix,
 31)
 32
 33if TYPE_CHECKING:
 34    from ..base import Model
 35    from ..connection import DatabaseConnection
 36    from ..utils import CursorWrapper
 37
 38
# Plan items — drift + policy + executable action
 40
 41
 42@dataclass
 43class PlanItem:
 44    """A planned convergence action: drift description + policy + optional fix."""
 45
 46    drift: Drift
 47    fix: Fix | None = None
 48    blocks_sync: bool = True
 49    drop_undeclared: bool = False
 50    guidance: str | None = None
 51
 52    def describe(self) -> str:
 53        if self.fix:
 54            return self.fix.describe()
 55        return self.drift.describe()
 56
 57
 58def _plan_drift(drift: Drift) -> PlanItem:
 59    """Map a semantic drift to a plan item with policy. All policy lives here."""
 60    match drift:
 61        case IndexDrift(kind=DriftKind.MISSING, table=t, index=idx, model=m):
 62            return PlanItem(drift, CreateIndexFix(t, idx, m), blocks_sync=False)
 63        case IndexDrift(kind=DriftKind.INVALID, table=t, index=idx, model=m):
 64            return PlanItem(drift, RebuildIndexFix(t, idx, m), blocks_sync=False)
 65        case IndexDrift(kind=DriftKind.CHANGED, table=t, index=idx, model=m):
 66            return PlanItem(drift, RebuildIndexFix(t, idx, m), blocks_sync=False)
 67        case IndexDrift(kind=DriftKind.RENAMED, table=t, old_name=old, new_name=new):
 68            return PlanItem(drift, RenameIndexFix(t, old, new), blocks_sync=False)
 69        case IndexDrift(kind=DriftKind.UNDECLARED, table=t, name=n):
 70            return PlanItem(
 71                drift, DropIndexFix(t, n), blocks_sync=False, drop_undeclared=True
 72            )
 73        case ConstraintDrift(kind=DriftKind.MISSING, table=t, constraint=c, model=m):
 74            return PlanItem(drift, AddConstraintFix(t, c, m))
 75        case ConstraintDrift(kind=DriftKind.UNVALIDATED, table=t, name=n):
 76            return PlanItem(drift, ValidateConstraintFix(t, n))
 77        case ConstraintDrift(kind=DriftKind.CHANGED):
 78            return PlanItem(
 79                drift,
 80                fix=None,
 81                guidance=(
 82                    "Declare a new constraint under a new name, run sync to add it,"
 83                    " then remove the old one with --drop-undeclared."
 84                ),
 85            )
 86        case ConstraintDrift(
 87            kind=DriftKind.RENAMED, table=t, old_name=old, new_name=new
 88        ):
 89            return PlanItem(drift, RenameConstraintFix(t, old, new), blocks_sync=False)
 90        case ConstraintDrift(kind=DriftKind.UNDECLARED, table=t, name=n):
 91            return PlanItem(drift, DropConstraintFix(t, n), drop_undeclared=True)
 92        case ForeignKeyDrift(
 93            kind=DriftKind.MISSING,
 94            table=t,
 95            name=cn,
 96            column=col,
 97            target_table=tt,
 98            target_column=tc,
 99        ):
100            return PlanItem(drift, AddForeignKeyFix(t, cn, col, tt, tc))
101        case ForeignKeyDrift(kind=DriftKind.UNVALIDATED, table=t, name=n):
102            return PlanItem(drift, ValidateConstraintFix(t, n))
103        case ForeignKeyDrift(kind=DriftKind.UNDECLARED, table=t, name=n):
104            return PlanItem(drift, DropConstraintFix(t, n), drop_undeclared=True)
105        case NullabilityDrift(
106            table=t, column=col, model_allows_null=False, has_null_rows=False
107        ):
108            return PlanItem(drift, SetNotNullFix(t, col))
109        case NullabilityDrift(model_allows_null=False, has_null_rows=True):
110            return PlanItem(
111                drift,
112                fix=None,
113                guidance="Backfill existing NULL rows, then rerun sync.",
114            )
115        case NullabilityDrift(table=t, column=col, model_allows_null=True):
116            return PlanItem(drift, DropNotNullFix(t, col))
117        case _:
118            raise ValueError(f"Unhandled drift: {drift}")
119
120
def can_auto_fix(drift: Drift) -> bool:
    """Report whether convergence can resolve *drift* without manual steps.

    The schema command uses this to label drifts as fixable or not. The
    answer is read straight off the plan item built by _plan_drift, so
    display and execution share one policy definition.
    """
    planned = _plan_drift(drift)
    return planned.fix is not None
128
129
130# Execution results
131
132
@dataclass
class FixResult:
    """What happened when one plan item's fix was applied."""

    # The plan item that was attempted.
    item: PlanItem
    # Value returned by the fix's apply() on success; None otherwise.
    sql: str | None = None
    # Exception raised while applying, or None on success.
    error: Exception | None = None

    @property
    def ok(self) -> bool:
        """True when no error was recorded for this item."""
        has_error = self.error is not None
        return not has_error
144
145
146# Convergence plan
147
148
@dataclass
class ConvergencePlan:
    """All planned convergence actions, ready for filtering and execution.

    ``items`` holds every planned action, including items with no fix
    (see ``blocked``) and undeclared-object drops that only run on request.
    """

    items: list[PlanItem]

    def executable(self, *, drop_undeclared: bool = False) -> list[PlanItem]:
        """Items eligible for execution in this mode, sorted by ``fix.pass_order``.

        An item is eligible when it has a fix and, unless *drop_undeclared*
        is True, it is not an undeclared-object drop. The sort is stable, so
        items with equal pass_order keep their relative order in ``items``;
        an already-ordered plan comes back in the same order.
        """
        eligible = [
            item
            for item in self.items
            if item.fix is not None and (drop_undeclared or not item.drop_undeclared)
        ]
        # Sort here instead of relying on whoever built ``items``: plans
        # constructed directly (not via the planning helpers) may be unsorted.
        return sorted(eligible, key=lambda item: item.fix.pass_order)

    def has_work(self, *, drop_undeclared: bool = False) -> bool:
        """Would this mode produce any items to execute?"""
        return bool(self.executable(drop_undeclared=drop_undeclared))

    @property
    def blocked(self) -> list[PlanItem]:
        """Items with no automatic fix (require a staged/manual rollout)."""
        return [item for item in self.items if item.fix is None]

    @property
    def blocking_cleanup(self) -> list[PlanItem]:
        """Undeclared-object drops that block sync success."""
        return [
            item for item in self.items if item.drop_undeclared and item.blocks_sync
        ]

    @property
    def optional_cleanup(self) -> list[PlanItem]:
        """Undeclared-object drops safe to defer."""
        return [
            item for item in self.items if item.drop_undeclared and not item.blocks_sync
        ]
185
186
187# Convergence result
188
189
@dataclass
class ConvergenceResult:
    """Outcome of executing convergence plan items.

    Wraps the per-item FixResults and derives counts, success flags, and a
    one-line summary from them.
    """

    results: list[FixResult] = field(default_factory=list)

    @property
    def applied(self) -> int:
        """Number of results with no recorded error."""
        return sum(1 for r in self.results if r.ok)

    @property
    def failed(self) -> int:
        """Number of results with a recorded error."""
        return sum(1 for r in self.results if not r.ok)

    @property
    def ok(self) -> bool:
        """True if every result succeeded (vacuously True when empty)."""
        return all(r.ok for r in self.results)

    @property
    def ok_for_sync(self) -> bool:
        """True if no sync-blocking items failed; non-blocking failures are ignored."""
        return all(r.ok for r in self.results if r.item.blocks_sync)

    @property
    def blocking_failures(self) -> list[FixResult]:
        """Failed results whose plan item blocks sync."""
        return [r for r in self.results if not r.ok and r.item.blocks_sync]

    @property
    def non_blocking_failures(self) -> list[FixResult]:
        """Failed results whose plan item does not block sync."""
        return [r for r in self.results if not r.ok and not r.item.blocks_sync]

    @property
    def summary(self) -> str:
        """One-line tally, e.g. ``"3 applied, 1 failed."``.

        An empty result reports ``"0 applied."`` instead of a bare ``"."``.
        """
        applied, failed = self.applied, self.failed
        parts = []
        # Always show the applied count unless failures alone tell the story.
        if applied or not failed:
            parts.append(f"{applied} applied")
        if failed:
            parts.append(f"{failed} failed")
        return ", ".join(parts) + "."
229
230
231# Plan construction
232
233
def plan_convergence() -> ConvergencePlan:
    """Build a convergence plan covering every registered model.

    Opens one cursor on the connection from get_connection(), analyzes each
    model from models_registry.get_models(), and maps every drift through
    _plan_drift. Items are then stably ordered by fix.pass_order, with
    fix-less items last.
    """
    connection = get_connection()

    with connection.cursor() as cursor:
        planned = [
            _plan_drift(drift)
            for model in models_registry.get_models()
            for drift in analyze_model(connection, cursor, model).drifts
        ]

    ordered = sorted(
        planned,
        key=lambda entry: entry.fix.pass_order if entry.fix else float("inf"),
    )
    return ConvergencePlan(items=ordered)
246
247
def plan_model_convergence(
    conn: DatabaseConnection, cursor: CursorWrapper, model: type[Model]
) -> ConvergencePlan:
    """Build a convergence plan for one model using the caller's cursor.

    Does not open its own connection or cursor. Items are stably ordered
    by fix.pass_order, with fix-less items last.
    """
    analysis = analyze_model(conn, cursor, model)
    ordered = sorted(
        map(_plan_drift, analysis.drifts),
        key=lambda entry: entry.fix.pass_order if entry.fix else float("inf"),
    )
    return ConvergencePlan(items=ordered)
255
256
def execute_plan(items: Sequence[PlanItem]) -> ConvergenceResult:
    """Apply each plan item's fix in turn, collecting one result per item.

    Every item is attempted even if an earlier one fails: an exception from
    ``fix.apply()`` is recorded on that item's FixResult instead of being
    propagated. (Whether each apply commits independently is up to the Fix
    implementation, which is not visible here.)

    Args:
        items: Plan items to apply, e.g. ``ConvergencePlan.executable()``.
            Every item must carry a fix.

    Returns:
        A ConvergenceResult with one FixResult per item, in input order.

    Raises:
        ValueError: if any item has no fix. This is checked before anything
            is applied, so invalid input applies nothing.
    """
    # Materialize once so validation and execution see the same items even
    # if a one-shot iterable is passed in.
    pending = list(items)

    # Explicit check instead of a per-item assert: asserts vanish under -O,
    # and failing mid-loop would leave earlier fixes already applied.
    unfixable = sum(1 for item in pending if item.fix is None)
    if unfixable:
        raise ValueError(
            f"execute_plan() got {unfixable} item(s) without a fix; "
            "pass ConvergencePlan.executable() instead."
        )

    result = ConvergenceResult()
    for item in pending:
        fix = item.fix
        assert fix is not None  # guaranteed by the check above; narrows the type
        try:
            sql = fix.apply()
        except Exception as exc:  # deliberate best-effort: record and continue
            result.results.append(FixResult(item=item, error=exc))
        else:
            result.results.append(FixResult(item=item, sql=sql))
    return result