1from __future__ import annotations
2
3from functools import total_ordering
4from typing import TYPE_CHECKING, Any, cast
5
6from plain.postgres.migrations.state import ProjectState
7
8from .exceptions import CircularDependencyError, NodeNotFoundError
9
10if TYPE_CHECKING:
11 from plain.postgres.migrations.migration import Migration
12
13
@total_ordering
class Node:
    """
    One vertex of the migration graph.

    A node is identified by its ``key`` and keeps direct references to the
    nodes on both sides of it: ``parents`` and ``children``. Equality,
    ordering and hashing are all delegated to ``key``, so a node compares
    equal to (and hashes like) its bare key tuple.
    """

    def __init__(self, key: tuple[str, str]):
        self.key = key
        self.children: set[Node] = set()
        self.parents: set[Node] = set()

    def __eq__(self, other: object) -> bool:
        """Compare by key; a non-Node operand is compared against the key itself."""
        other_key = other.key if isinstance(other, Node) else other
        return self.key == other_key

    def __lt__(self, other: object) -> bool:
        """Order by key against another Node or a tuple; otherwise NotImplemented."""
        if isinstance(other, Node):
            rhs = other.key
        elif isinstance(other, tuple):
            rhs = cast(tuple[str, str], other)
        else:
            return NotImplemented
        return self.key < rhs

    def __hash__(self) -> int:
        # Must agree with __eq__: a node hashes exactly like its key.
        return hash(self.key)

    def __getitem__(self, item: int) -> str:
        """Index straight into the key."""
        return self.key[item]

    def __str__(self) -> str:
        return str(self.key)

    def __repr__(self) -> str:
        first, second = self.key[0], self.key[1]
        return f"<{self.__class__.__name__}: ({first!r}, {second!r})>"

    def add_child(self, child: Node) -> None:
        """Record ``child`` in this node's set of children."""
        self.children.add(child)

    def add_parent(self, parent: Node) -> None:
        """Record ``parent`` in this node's set of parents."""
        self.parents.add(parent)
55
56
class DummyNode(Node):
    """
    Placeholder node for a key that has no migration file on disk
    (for example, a squashed migration that was removed).

    Once the migration graph has been processed, all dummy nodes should have
    been removed. Any that are left indicate a nonexistent dependency, which
    is reported through ``raise_error()``.
    """

    def __init__(
        self,
        key: tuple[str, str],
        origin: Migration | tuple[str, str] | None,
        error_message: str,
    ):
        super().__init__(key)
        # Kept only so raise_error() can build the exception later.
        self.error_message = error_message
        self.origin = origin

    def raise_error(self) -> None:
        """Raise NodeNotFoundError carrying the stored message, key and origin."""
        error = NodeNotFoundError(self.error_message, self.key, origin=self.origin)
        raise error
78
79
class MigrationGraph:
    """
    Represent the digraph of all migrations in a project.

    Each migration is a node, and each dependency is an edge. There are
    no implicit dependencies between numbered migrations - the numbering is
    merely a convention to aid file listing. Every new numbered migration
    has a declared dependency to the previous number, meaning that VCS
    branch merges can be detected and resolved.

    Migration files can be marked as replacing another set of migrations -
    this is to support the "squash" feature. The graph handler isn't responsible
    for these; instead, the code to load them in here should examine the
    migration files and if the replaced migrations are all either unapplied
    or not present, it should ignore the replaced ones, load in just the
    replacing migration, and repoint any dependencies that pointed to the
    replaced migrations to point to the replacing one.

    A node should be a tuple: (app_path, migration_name). The tree special-cases
    things within an app - namely, root nodes and leaf nodes ignore dependencies
    to other packages.
    """

    def __init__(self):
        # key -> graph vertex (holds the parent/child links)
        self.node_map: dict[tuple[str, str], Node] = {}
        # key -> migration object; None for dummy nodes
        self.nodes: dict[tuple[str, str], Migration | None] = {}

    def add_node(self, key: tuple[str, str], migration: Migration) -> None:
        """Register `migration` under `key`. The key must not be present yet."""
        assert key not in self.node_map
        node = Node(key)
        self.node_map[key] = node
        self.nodes[key] = migration

    def add_dummy_node(
        self,
        key: tuple[str, str],
        origin: Migration | tuple[str, str] | None,
        error_message: str,
    ) -> None:
        """
        Register a DummyNode placeholder under `key`, with no migration
        attached. `origin` and `error_message` are stored on the node for
        the error it raises if it is still present at validation time.
        """
        node = DummyNode(key, origin, error_message)
        self.node_map[key] = node
        self.nodes[key] = None

    def add_dependency(
        self,
        migration: Migration | tuple[str, str] | None,
        child: tuple[str, str],
        parent: tuple[str, str],
        skip_validation: bool = False,
    ) -> None:
        """
        Link `child` to `parent` (in both directions). `migration` is only
        used for error reporting.

        This may create dummy nodes if they don't yet exist. If
        `skip_validation=True`, validate_consistency() should be called
        afterward.
        """
        if child not in self.nodes:
            error_message = (
                f"Migration {migration} dependencies reference nonexistent"
                f" child node {child!r}"
            )
            self.add_dummy_node(child, migration, error_message)
        if parent not in self.nodes:
            error_message = (
                f"Migration {migration} dependencies reference nonexistent"
                f" parent node {parent!r}"
            )
            self.add_dummy_node(parent, migration, error_message)
        self.node_map[child].add_parent(self.node_map[parent])
        self.node_map[parent].add_child(self.node_map[child])
        if not skip_validation:
            self.validate_consistency()

    def remove_replaced_nodes(
        self, replacement: tuple[str, str], replaced: list[tuple[str, str]]
    ) -> None:
        """
        Remove each of the `replaced` nodes (when they exist). Any
        dependencies that were referencing them are changed to reference the
        `replacement` node instead.

        Raise NodeNotFoundError if `replacement` is not in the graph.
        """
        # Cast list of replaced keys to set to speed up lookup later.
        replaced_set: set[tuple[str, str]] = set(replaced)
        try:
            replacement_node = self.node_map[replacement]
        except KeyError as err:
            raise NodeNotFoundError(
                f"Unable to find replacement node {replacement!r}. It was either never added"
                " to the migration graph, or has been removed.",
                replacement,
            ) from err
        for replaced_key in replaced_set:
            self.nodes.pop(replaced_key, None)
            replaced_node = self.node_map.pop(replaced_key, None)
            if replaced_node is None:
                continue
            for child in replaced_node.children:
                child.parents.remove(replaced_node)
                # We don't want to create dependencies between the replaced
                # node and the replacement node as this would lead to
                # self-referencing on the replacement node at a later iteration.
                if child.key not in replaced_set:
                    replacement_node.add_child(child)
                    child.add_parent(replacement_node)
            for parent in replaced_node.parents:
                parent.children.remove(replaced_node)
                # Again, to avoid self-referencing.
                if parent.key not in replaced_set:
                    replacement_node.add_parent(parent)
                    parent.add_child(replacement_node)

    def remove_replacement_node(
        self, replacement: tuple[str, str], replaced: list[tuple[str, str]]
    ) -> None:
        """
        The inverse operation to `remove_replaced_nodes`. Almost. Remove the
        replacement node `replacement` and remap its child nodes to `replaced`
        - the list of nodes it would have replaced. Don't remap its parent
        nodes as they are expected to be correct already.

        Raise NodeNotFoundError if `replacement` is not in the graph.
        """
        self.nodes.pop(replacement, None)
        try:
            replacement_node = self.node_map.pop(replacement)
        except KeyError as err:
            raise NodeNotFoundError(
                f"Unable to remove replacement node {replacement!r}. It was either never added"
                " to the migration graph, or has been removed already.",
                replacement,
            ) from err
        replaced_nodes: set[Node] = set()
        replaced_nodes_parents: set[Node] = set()
        for key in replaced:
            replaced_node = self.node_map.get(key)
            if replaced_node is not None:
                replaced_nodes.add(replaced_node)
                replaced_nodes_parents |= replaced_node.parents
        # We're only interested in the latest replaced node, so filter out
        # replaced nodes that are parents of other replaced nodes.
        replaced_nodes -= replaced_nodes_parents
        for child in replacement_node.children:
            child.parents.remove(replacement_node)
            for replaced_node in replaced_nodes:
                replaced_node.add_child(child)
                child.add_parent(replaced_node)
        for parent in replacement_node.parents:
            parent.children.remove(replacement_node)
            # NOTE: There is no need to remap parent dependencies as we can
            # assume the replaced nodes already have the correct ancestry.

    def validate_consistency(self) -> None:
        """
        Ensure there are no dummy nodes remaining in the graph.

        The first DummyNode found raises its stored error.
        """
        for node in self.node_map.values():
            if isinstance(node, DummyNode):
                node.raise_error()

    def forwards_plan(self, target: tuple[str, str]) -> list[tuple[str, str]]:
        """
        Given a node, return a list of which previous nodes (dependencies) must
        be applied, ending with the node itself. This is the list you would
        follow if applying the migrations to a database.

        Raise NodeNotFoundError if `target` is not in the graph.
        """
        if target not in self.nodes:
            raise NodeNotFoundError(f"Node {target!r} not a valid node", target)
        return self.iterative_dfs(self.node_map[target])

    def iterative_dfs(
        self, start: Node, forwards: bool = True
    ) -> list[tuple[str, str]]:
        """
        Iterative depth-first search for finding dependencies.

        Follow parent links when `forwards` is true, child links otherwise,
        and return the visited keys in post-order: a node's key is appended
        only after everything reachable from it, so `start` comes last.
        """
        visited: list[tuple[str, str]] = []
        visited_set: set[Node] = set()
        # Each entry is (node, processed). A node is pushed once unprocessed,
        # then re-pushed as processed underneath its neighbours so that it is
        # emitted after them.
        stack: list[tuple[Node, bool]] = [(start, False)]
        while stack:
            node, processed = stack.pop()
            if node in visited_set:
                continue
            if processed:
                visited_set.add(node)
                visited.append(node.key)
                continue
            stack.append((node, True))
            neighbours = node.parents if forwards else node.children
            # Sorting makes the traversal order independent of set ordering.
            stack.extend((n, False) for n in sorted(neighbours))
        return visited

    def root_nodes(self, app: str | None = None) -> list[tuple[str, str]]:
        """
        Return all root nodes - that is, nodes with no dependencies inside
        their app. These are the starting point for an app.

        When `app` is given, only nodes of that app are returned. The result
        is sorted.
        """
        roots: set[tuple[str, str]] = set()
        for key in self.nodes:
            # Apply the cheap app filter before scanning the node's edges.
            if app and app != key[0]:
                continue
            if all(parent[0] != key[0] for parent in self.node_map[key].parents):
                roots.add(key)
        return sorted(roots)

    def leaf_nodes(self, app: str | None = None) -> list[tuple[str, str]]:
        """
        Return all leaf nodes - that is, nodes with no dependents in their app.
        These are the "most current" version of an app's schema.
        Having more than one per app is technically an error, but one that
        gets handled further up, in the interactive command - it's usually the
        result of a VCS merge and needs some user input.

        When `app` is given, only nodes of that app are returned. The result
        is sorted.
        """
        leaves: set[tuple[str, str]] = set()
        for key in self.nodes:
            # Apply the cheap app filter before scanning the node's edges.
            if app and app != key[0]:
                continue
            if all(child[0] != key[0] for child in self.node_map[key].children):
                leaves.add(key)
        return sorted(leaves)

    def ensure_not_cyclic(self) -> None:
        """
        Raise CircularDependencyError if the graph contains a cycle.

        The error message lists the keys forming the cycle as
        "app.name" entries joined by ", ".
        """
        # Algo from GvR:
        # https://neopythonic.blogspot.com/2009/01/detecting-cycles-in-directed-graph.html
        todo: set[tuple[str, str]] = set(self.nodes)
        while todo:
            stack: list[tuple[str, str]] = [todo.pop()]
            while stack:
                top = stack[-1]
                for child in self.node_map[top].children:
                    # Use child.key instead of child to speed up the frequent
                    # hashing.
                    child_key = child.key
                    if child_key in stack:
                        # Reached a key that is already on the current path.
                        cycle = stack[stack.index(child_key) :]
                        raise CircularDependencyError(
                            ", ".join("{}.{}".format(*n) for n in cycle)
                        )
                    if child_key in todo:
                        # Descend into the first unexplored child.
                        stack.append(child_key)
                        todo.remove(child_key)
                        break
                else:
                    # No unexplored children left under `top`; backtrack.
                    stack.pop()

    def __str__(self) -> str:
        return "Graph: {} nodes, {} edges".format(*self._nodes_and_edges())

    def __repr__(self) -> str:
        nodes, edges = self._nodes_and_edges()
        return f"<{self.__class__.__name__}: nodes={nodes}, edges={edges}>"

    def _nodes_and_edges(self) -> tuple[int, int]:
        """Return (node count, edge count); edges are counted via parent links."""
        return len(self.nodes), sum(
            len(node.parents) for node in self.node_map.values()
        )

    def _generate_plan(
        self, nodes: list[tuple[str, str]], at_end: bool
    ) -> list[tuple[str, str]]:
        """
        Merge the forwards plans of `nodes` into one list without duplicates,
        keeping first-seen order. When `at_end` is false, the keys listed in
        `nodes` themselves are left out of the plan.
        """
        plan: list[tuple[str, str]] = []
        # Set mirrors of `plan` and `nodes` so membership tests stay O(1)
        # instead of rescanning the lists for every migration.
        planned: set[tuple[str, str]] = set()
        targets = set(nodes)
        for node in nodes:
            for migration in self.forwards_plan(node):
                if migration in planned:
                    continue
                if at_end or migration not in targets:
                    plan.append(migration)
                    planned.add(migration)
        return plan

    def make_state(
        self,
        nodes: tuple[str, str] | list[tuple[str, str]] | None = None,
        at_end: bool = True,
        real_packages: Any = None,
    ) -> ProjectState:
        """
        Given a migration node or nodes, return a complete ProjectState for it.
        If at_end is False, return the state before the migration has run.
        If nodes is not provided, return the overall most current project state.
        """
        if nodes is None:
            nodes = list(self.leaf_nodes())
        if not nodes:
            # NOTE(review): this early return does not forward `real_packages`
            # to ProjectState - confirm that is intentional.
            return ProjectState()
        # A single key is a tuple of strings, so its first item is not a
        # tuple; wrap it to get a one-element list of keys.
        if not isinstance(nodes[0], tuple):
            nodes = cast(list[tuple[str, str]], [nodes])
        assert isinstance(nodes, list)  # Type narrowing after checks above
        plan = self._generate_plan(nodes, at_end)
        project_state = ProjectState(real_packages=real_packages)
        for node in plan:
            project_state = self.nodes[node].mutate_state(project_state, preserve=False)  # type: ignore[attr-defined]
        return project_state

    def __contains__(self, node: tuple[str, str]) -> bool:
        """Return whether `node` is a key registered in the graph."""
        return node in self.nodes