1from __future__ import annotations
  2
  3from functools import total_ordering
  4from typing import TYPE_CHECKING, Any, cast
  5
  6from plain.postgres.migrations.state import ProjectState
  7
  8from .exceptions import CircularDependencyError, NodeNotFoundError
  9
 10if TYPE_CHECKING:
 11    from plain.postgres.migrations.migration import Migration
 12
 13
 14@total_ordering
 15class Node:
 16    """
 17    A single node in the migration graph. Contains direct links to adjacent
 18    nodes in either direction.
 19    """
 20
 21    def __init__(self, key: tuple[str, str]):
 22        self.key = key
 23        self.children: set[Node] = set()
 24        self.parents: set[Node] = set()
 25
 26    def __eq__(self, other: object) -> bool:
 27        if isinstance(other, Node):
 28            return self.key == other.key
 29        return self.key == other
 30
 31    def __lt__(self, other: object) -> bool:
 32        if isinstance(other, Node):
 33            return self.key < other.key
 34        if isinstance(other, tuple):
 35            return self.key < cast(tuple[str, str], other)
 36        return NotImplemented
 37
 38    def __hash__(self) -> int:
 39        return hash(self.key)
 40
 41    def __getitem__(self, item: int) -> str:
 42        return self.key[item]
 43
 44    def __str__(self) -> str:
 45        return str(self.key)
 46
 47    def __repr__(self) -> str:
 48        return f"<{self.__class__.__name__}: ({self.key[0]!r}, {self.key[1]!r})>"
 49
 50    def add_child(self, child: Node) -> None:
 51        self.children.add(child)
 52
 53    def add_parent(self, parent: Node) -> None:
 54        self.parents.add(parent)
 55
 56
 57class DummyNode(Node):
 58    """
 59    A node that doesn't correspond to a migration file on disk.
 60    (A squashed migration that was removed, for example.)
 61
 62    After the migration graph is processed, all dummy nodes should be removed.
 63    If there are any left, a nonexistent dependency error is raised.
 64    """
 65
 66    def __init__(
 67        self,
 68        key: tuple[str, str],
 69        origin: Migration | tuple[str, str] | None,
 70        error_message: str,
 71    ):
 72        super().__init__(key)
 73        self.origin = origin
 74        self.error_message = error_message
 75
 76    def raise_error(self) -> None:
 77        raise NodeNotFoundError(self.error_message, self.key, origin=self.origin)
 78
 79
 80class MigrationGraph:
 81    """
 82    Represent the digraph of all migrations in a project.
 83
 84    Each migration is a node, and each dependency is an edge. There are
 85    no implicit dependencies between numbered migrations - the numbering is
 86    merely a convention to aid file listing. Every new numbered migration
 87    has a declared dependency to the previous number, meaning that VCS
 88    branch merges can be detected and resolved.
 89
 90    Migrations files can be marked as replacing another set of migrations -
 91    this is to support the "squash" feature. The graph handler isn't responsible
 92    for these; instead, the code to load them in here should examine the
 93    migration files and if the replaced migrations are all either unapplied
 94    or not present, it should ignore the replaced ones, load in just the
 95    replacing migration, and repoint any dependencies that pointed to the
 96    replaced migrations to point to the replacing one.
 97
 98    A node should be a tuple: (app_path, migration_name). The tree special-cases
 99    things within an app - namely, root nodes and leaf nodes ignore dependencies
100    to other packages.
101    """
102
103    def __init__(self):
104        self.node_map: dict[tuple[str, str], Node] = {}
105        self.nodes: dict[tuple[str, str], Migration | None] = {}
106
107    def add_node(self, key: tuple[str, str], migration: Migration) -> None:
108        assert key not in self.node_map
109        node = Node(key)
110        self.node_map[key] = node
111        self.nodes[key] = migration
112
113    def add_dummy_node(
114        self,
115        key: tuple[str, str],
116        origin: Migration | tuple[str, str] | None,
117        error_message: str,
118    ) -> None:
119        node = DummyNode(key, origin, error_message)
120        self.node_map[key] = node
121        self.nodes[key] = None
122
123    def add_dependency(
124        self,
125        migration: Migration | tuple[str, str] | None,
126        child: tuple[str, str],
127        parent: tuple[str, str],
128        skip_validation: bool = False,
129    ) -> None:
130        """
131        This may create dummy nodes if they don't yet exist. If
132        `skip_validation=True`, validate_consistency() should be called
133        afterward.
134        """
135        if child not in self.nodes:
136            error_message = (
137                f"Migration {migration} dependencies reference nonexistent"
138                f" child node {child!r}"
139            )
140            self.add_dummy_node(child, migration, error_message)
141        if parent not in self.nodes:
142            error_message = (
143                f"Migration {migration} dependencies reference nonexistent"
144                f" parent node {parent!r}"
145            )
146            self.add_dummy_node(parent, migration, error_message)
147        self.node_map[child].add_parent(self.node_map[parent])
148        self.node_map[parent].add_child(self.node_map[child])
149        if not skip_validation:
150            self.validate_consistency()
151
152    def remove_replaced_nodes(
153        self, replacement: tuple[str, str], replaced: list[tuple[str, str]]
154    ) -> None:
155        """
156        Remove each of the `replaced` nodes (when they exist). Any
157        dependencies that were referencing them are changed to reference the
158        `replacement` node instead.
159        """
160        # Cast list of replaced keys to set to speed up lookup later.
161        replaced_set: set[tuple[str, str]] = set(replaced)
162        try:
163            replacement_node = self.node_map[replacement]
164        except KeyError as err:
165            raise NodeNotFoundError(
166                f"Unable to find replacement node {replacement!r}. It was either never added"
167                " to the migration graph, or has been removed.",
168                replacement,
169            ) from err
170        for replaced_key in replaced_set:
171            self.nodes.pop(replaced_key, None)
172            replaced_node = self.node_map.pop(replaced_key, None)
173            if replaced_node:
174                for child in replaced_node.children:
175                    child.parents.remove(replaced_node)
176                    # We don't want to create dependencies between the replaced
177                    # node and the replacement node as this would lead to
178                    # self-referencing on the replacement node at a later iteration.
179                    if child.key not in replaced_set:
180                        replacement_node.add_child(child)
181                        child.add_parent(replacement_node)
182                for parent in replaced_node.parents:
183                    parent.children.remove(replaced_node)
184                    # Again, to avoid self-referencing.
185                    if parent.key not in replaced_set:
186                        replacement_node.add_parent(parent)
187                        parent.add_child(replacement_node)
188
189    def remove_replacement_node(
190        self, replacement: tuple[str, str], replaced: list[tuple[str, str]]
191    ) -> None:
192        """
193        The inverse operation to `remove_replaced_nodes`. Almost. Remove the
194        replacement node `replacement` and remap its child nodes to `replaced`
195        - the list of nodes it would have replaced. Don't remap its parent
196        nodes as they are expected to be correct already.
197        """
198        self.nodes.pop(replacement, None)
199        try:
200            replacement_node = self.node_map.pop(replacement)
201        except KeyError as err:
202            raise NodeNotFoundError(
203                f"Unable to remove replacement node {replacement!r}. It was either never added"
204                " to the migration graph, or has been removed already.",
205                replacement,
206            ) from err
207        replaced_nodes: set[Node] = set()
208        replaced_nodes_parents: set[Node] = set()
209        for key in replaced:
210            replaced_node = self.node_map.get(key)
211            if replaced_node:
212                replaced_nodes.add(replaced_node)
213                replaced_nodes_parents |= replaced_node.parents
214        # We're only interested in the latest replaced node, so filter out
215        # replaced nodes that are parents of other replaced nodes.
216        replaced_nodes -= replaced_nodes_parents
217        for child in replacement_node.children:
218            child.parents.remove(replacement_node)
219            for replaced_node in replaced_nodes:
220                replaced_node.add_child(child)
221                child.add_parent(replaced_node)
222        for parent in replacement_node.parents:
223            parent.children.remove(replacement_node)
224            # NOTE: There is no need to remap parent dependencies as we can
225            # assume the replaced nodes already have the correct ancestry.
226
227    def validate_consistency(self) -> None:
228        """Ensure there are no dummy nodes remaining in the graph."""
229        [n.raise_error() for n in self.node_map.values() if isinstance(n, DummyNode)]
230
231    def forwards_plan(self, target: tuple[str, str]) -> list[tuple[str, str]]:
232        """
233        Given a node, return a list of which previous nodes (dependencies) must
234        be applied, ending with the node itself. This is the list you would
235        follow if applying the migrations to a database.
236        """
237        if target not in self.nodes:
238            raise NodeNotFoundError(f"Node {target!r} not a valid node", target)
239        return self.iterative_dfs(self.node_map[target])
240
241    def iterative_dfs(
242        self, start: Node, forwards: bool = True
243    ) -> list[tuple[str, str]]:
244        """Iterative depth-first search for finding dependencies."""
245        visited: list[tuple[str, str]] = []
246        visited_set: set[Node] = set()
247        stack: list[tuple[Node, bool]] = [(start, False)]
248        while stack:
249            node, processed = stack.pop()
250            if node in visited_set:
251                pass
252            elif processed:
253                visited_set.add(node)
254                visited.append(node.key)
255            else:
256                stack.append((node, True))
257                stack += [
258                    (n, False)
259                    for n in sorted(node.parents if forwards else node.children)
260                ]
261        return visited
262
263    def root_nodes(self, app: str | None = None) -> list[tuple[str, str]]:
264        """
265        Return all root nodes - that is, nodes with no dependencies inside
266        their app. These are the starting point for an app.
267        """
268        roots: set[tuple[str, str]] = set()
269        for node in self.nodes:
270            if all(key[0] != node[0] for key in self.node_map[node].parents) and (
271                not app or app == node[0]
272            ):
273                roots.add(node)
274        return sorted(roots)
275
276    def leaf_nodes(self, app: str | None = None) -> list[tuple[str, str]]:
277        """
278        Return all leaf nodes - that is, nodes with no dependents in their app.
279        These are the "most current" version of an app's schema.
280        Having more than one per app is technically an error, but one that
281        gets handled further up, in the interactive command - it's usually the
282        result of a VCS merge and needs some user input.
283        """
284        leaves: set[tuple[str, str]] = set()
285        for node in self.nodes:
286            if all(key[0] != node[0] for key in self.node_map[node].children) and (
287                not app or app == node[0]
288            ):
289                leaves.add(node)
290        return sorted(leaves)
291
292    def ensure_not_cyclic(self) -> None:
293        # Algo from GvR:
294        # https://neopythonic.blogspot.com/2009/01/detecting-cycles-in-directed-graph.html
295        todo: set[tuple[str, str]] = set(self.nodes)
296        while todo:
297            node = todo.pop()
298            stack: list[tuple[str, str]] = [node]
299            while stack:
300                top = stack[-1]
301                for child in self.node_map[top].children:
302                    # Use child.key instead of child to speed up the frequent
303                    # hashing.
304                    node = child.key
305                    if node in stack:
306                        cycle = stack[stack.index(node) :]
307                        raise CircularDependencyError(
308                            ", ".join("{}.{}".format(*n) for n in cycle)
309                        )
310                    if node in todo:
311                        stack.append(node)
312                        todo.remove(node)
313                        break
314                else:
315                    node = stack.pop()
316
317    def __str__(self) -> str:
318        return "Graph: {} nodes, {} edges".format(*self._nodes_and_edges())
319
320    def __repr__(self) -> str:
321        nodes, edges = self._nodes_and_edges()
322        return f"<{self.__class__.__name__}: nodes={nodes}, edges={edges}>"
323
324    def _nodes_and_edges(self) -> tuple[int, int]:
325        return len(self.nodes), sum(
326            len(node.parents) for node in self.node_map.values()
327        )
328
329    def _generate_plan(
330        self, nodes: list[tuple[str, str]], at_end: bool
331    ) -> list[tuple[str, str]]:
332        plan: list[tuple[str, str]] = []
333        for node in nodes:
334            for migration in self.forwards_plan(node):
335                if migration not in plan and (at_end or migration not in nodes):
336                    plan.append(migration)
337        return plan
338
339    def make_state(
340        self,
341        nodes: tuple[str, str] | list[tuple[str, str]] | None = None,
342        at_end: bool = True,
343        real_packages: Any = None,
344    ) -> ProjectState:
345        """
346        Given a migration node or nodes, return a complete ProjectState for it.
347        If at_end is False, return the state before the migration has run.
348        If nodes is not provided, return the overall most current project state.
349        """
350        if nodes is None:
351            nodes = list(self.leaf_nodes())
352        if not nodes:
353            return ProjectState()
354        if not isinstance(nodes[0], tuple):
355            nodes = cast(list[tuple[str, str]], [nodes])
356        assert isinstance(nodes, list)  # Type narrowing after checks above
357        plan = self._generate_plan(nodes, at_end)
358        project_state = ProjectState(real_packages=real_packages)
359        for node in plan:
360            project_state = self.nodes[node].mutate_state(project_state, preserve=False)  # type: ignore[attr-defined]
361        return project_state
362
363    def __contains__(self, node: tuple[str, str]) -> bool:
364        return node in self.nodes