1from functools import total_ordering
2
3from plain.models.migrations.state import ProjectState
4
5from .exceptions import CircularDependencyError, NodeNotFoundError
6
7
@total_ordering
class Node:
    """
    One vertex of the migration graph.

    Wraps a migration ``key`` and keeps direct references to neighbouring
    nodes in both directions: ``parents`` (nodes this one depends on) and
    ``children`` (nodes that depend on this one).

    Equality, ordering, hashing and item access all delegate to ``key``, so a
    Node compares with, hashes like, and indexes like its bare key.
    """

    def __init__(self, key):
        self.key = key
        self.children = set()
        self.parents = set()

    def __eq__(self, other):
        # `other` may be another Node or a bare key; either way the result is
        # decided by comparing against this node's key.
        return self.key == other

    def __lt__(self, other):
        # functools.total_ordering derives <=, > and >= from this and __eq__.
        return self.key < other

    def __hash__(self):
        # Hash exactly like the key so Nodes and bare keys can be mixed in
        # sets and dict lookups (consistent with __eq__ above).
        return hash(self.key)

    def __getitem__(self, item):
        # Lets callers read parts of the key directly, e.g. node[0].
        return self.key[item]

    def __str__(self):
        return str(self.key)

    def __repr__(self):
        app, name = self.key[0], self.key[1]
        return f"<{type(self).__name__}: ({app!r}, {name!r})>"

    def add_child(self, child):
        """Record `child` as a node that depends on this one."""
        self.children.add(child)

    def add_parent(self, parent):
        """Record `parent` as a node this one depends on."""
        self.parents.add(parent)
43
44
class DummyNode(Node):
    """
    Placeholder for a key that was referenced but has no migration behind it
    (for example, a squashed migration whose file was removed).

    Once the migration graph has been processed, no dummy nodes should be
    left. Any that survive are reported through raise_error() as a
    nonexistent-dependency error.
    """

    def __init__(self, key, origin, error_message):
        super().__init__(key)
        # What referenced the missing key; the graph passes the migration
        # whose dependencies mentioned it.
        self.origin = origin
        # Message used when this placeholder is reported.
        self.error_message = error_message

    def raise_error(self):
        """Raise NodeNotFoundError describing this missing node."""
        error = NodeNotFoundError(self.error_message, self.key, origin=self.origin)
        raise error
61
62
class MigrationGraph:
    """
    Represent the digraph of all migrations in a project.

    Each migration is a node, and each dependency is an edge. There are
    no implicit dependencies between numbered migrations - the numbering is
    merely a convention to aid file listing. Every new numbered migration
    has a declared dependency to the previous number, meaning that VCS
    branch merges can be detected and resolved.

    Migration files can be marked as replacing another set of migrations -
    this is to support the "squash" feature. The graph handler isn't responsible
    for these; instead, the code to load them in here should examine the
    migration files and if the replaced migrations are all either unapplied
    or not present, it should ignore the replaced ones, load in just the
    replacing migration, and repoint any dependencies that pointed to the
    replaced migrations to point to the replacing one.

    A node should be a tuple: (app_path, migration_name). The tree special-cases
    things within an app - namely, root nodes and leaf nodes ignore dependencies
    to other packages.
    """

    def __init__(self):
        # key -> Node (or DummyNode); holds the parent/child links.
        self.node_map = {}
        # key -> migration object (None for dummy nodes).
        self.nodes = {}

    def add_node(self, key, migration):
        """
        Register `migration` under `key`.

        `key` must not already be present. This is checked with `assert`,
        so it is not enforced when Python runs with -O.
        """
        assert key not in self.node_map
        node = Node(key)
        self.node_map[key] = node
        self.nodes[key] = migration

    def add_dummy_node(self, key, origin, error_message):
        """
        Register a placeholder for `key`, which is referenced but has no
        migration. validate_consistency() raises `error_message` for any
        placeholder still present.
        """
        node = DummyNode(key, origin, error_message)
        self.node_map[key] = node
        self.nodes[key] = None

    def add_dependency(self, migration, child, parent, skip_validation=False):
        """
        Record that `child` depends on `parent` (an edge declared by
        `migration`).

        This may create dummy nodes if they don't yet exist. If
        `skip_validation=True`, validate_consistency() should be called
        afterward.
        """
        if child not in self.nodes:
            error_message = (
                f"Migration {migration} dependencies reference nonexistent"
                f" child node {child!r}"
            )
            self.add_dummy_node(child, migration, error_message)
        if parent not in self.nodes:
            error_message = (
                f"Migration {migration} dependencies reference nonexistent"
                f" parent node {parent!r}"
            )
            self.add_dummy_node(parent, migration, error_message)
        self.node_map[child].add_parent(self.node_map[parent])
        self.node_map[parent].add_child(self.node_map[child])
        if not skip_validation:
            self.validate_consistency()

    def remove_replaced_nodes(self, replacement, replaced):
        """
        Remove each of the `replaced` nodes (when they exist). Any
        dependencies that were referencing them are changed to reference the
        `replacement` node instead.

        Raise NodeNotFoundError if `replacement` is not in the graph.
        """
        # Cast list of replaced keys to set to speed up lookup later.
        replaced = set(replaced)
        try:
            replacement_node = self.node_map[replacement]
        except KeyError as err:
            raise NodeNotFoundError(
                f"Unable to find replacement node {replacement!r}. It was either "
                "never added to the migration graph, or has been removed.",
                replacement,
            ) from err
        for replaced_key in replaced:
            self.nodes.pop(replaced_key, None)
            replaced_node = self.node_map.pop(replaced_key, None)
            if replaced_node is None:
                continue
            for child in replaced_node.children:
                child.parents.remove(replaced_node)
                # We don't want to create dependencies between the replaced
                # node and the replacement node as this would lead to
                # self-referencing on the replacement node at a later iteration.
                if child.key not in replaced:
                    replacement_node.add_child(child)
                    child.add_parent(replacement_node)
            for parent in replaced_node.parents:
                parent.children.remove(replaced_node)
                # Again, to avoid self-referencing.
                if parent.key not in replaced:
                    replacement_node.add_parent(parent)
                    parent.add_child(replacement_node)

    def remove_replacement_node(self, replacement, replaced):
        """
        The inverse operation to `remove_replaced_nodes`. Almost. Remove the
        replacement node `replacement` and remap its child nodes to `replaced`
        - the list of nodes it would have replaced. Don't remap its parent
        nodes as they are expected to be correct already.

        Raise NodeNotFoundError if `replacement` is not in the graph.
        """
        self.nodes.pop(replacement, None)
        try:
            replacement_node = self.node_map.pop(replacement)
        except KeyError as err:
            raise NodeNotFoundError(
                f"Unable to remove replacement node {replacement!r}. It was either "
                "never added to the migration graph, or has been removed already.",
                replacement,
            ) from err
        replaced_nodes = set()
        replaced_nodes_parents = set()
        for key in replaced:
            replaced_node = self.node_map.get(key)
            if replaced_node:
                replaced_nodes.add(replaced_node)
                replaced_nodes_parents |= replaced_node.parents
        # We're only interested in the latest replaced node, so filter out
        # replaced nodes that are parents of other replaced nodes.
        replaced_nodes -= replaced_nodes_parents
        for child in replacement_node.children:
            child.parents.remove(replacement_node)
            for replaced_node in replaced_nodes:
                replaced_node.add_child(child)
                child.add_parent(replaced_node)
        for parent in replacement_node.parents:
            parent.children.remove(replacement_node)
            # NOTE: There is no need to remap parent dependencies as we can
            # assume the replaced nodes already have the correct ancestry.

    def validate_consistency(self):
        """
        Ensure there are no dummy nodes remaining in the graph.

        Raise NodeNotFoundError (via DummyNode.raise_error) for the first
        dummy node found.
        """
        for node in self.node_map.values():
            if isinstance(node, DummyNode):
                node.raise_error()

    def forwards_plan(self, target):
        """
        Given a node, return a list of which previous nodes (dependencies) must
        be applied, ending with the node itself. This is the list you would
        follow if applying the migrations to a database.

        Raise NodeNotFoundError if `target` is not in the graph.
        """
        if target not in self.nodes:
            raise NodeNotFoundError(f"Node {target!r} not a valid node", target)
        return self.iterative_dfs(self.node_map[target])

    def backwards_plan(self, target):
        """
        Given a node, return a list of which dependent nodes (dependencies)
        must be unapplied, ending with the node itself. This is the list you
        would follow if removing the migrations from a database.

        Raise NodeNotFoundError if `target` is not in the graph.
        """
        if target not in self.nodes:
            raise NodeNotFoundError(f"Node {target!r} not a valid node", target)
        return self.iterative_dfs(self.node_map[target], forwards=False)

    def iterative_dfs(self, start, forwards=True):
        """
        Iterative depth-first search for finding dependencies.

        Return keys in post-order. Each node's key appears after the keys of
        every node reachable from it, following `parents` when `forwards` is
        true and `children` otherwise. The list ends with `start`'s key.
        """
        visited = []
        visited_set = set()
        stack = [(start, False)]
        while stack:
            node, processed = stack.pop()
            if node in visited_set:
                continue
            if processed:
                visited_set.add(node)
                visited.append(node.key)
                continue
            # Re-push the node marked as processed *below* its neighbours, so
            # it is only emitted after all of them have been emitted.
            stack.append((node, True))
            neighbours = node.parents if forwards else node.children
            stack += [(n, False) for n in sorted(neighbours)]
        return visited

    def root_nodes(self, app=None):
        """
        Return all root nodes - that is, nodes with no dependencies inside
        their app. These are the starting point for an app.

        If `app` is given, only that app's root nodes are returned. The result
        is sorted.
        """
        roots = set()
        for node in self.nodes:
            if all(parent[0] != node[0] for parent in self.node_map[node].parents) and (
                not app or app == node[0]
            ):
                roots.add(node)
        return sorted(roots)

    def leaf_nodes(self, app=None):
        """
        Return all leaf nodes - that is, nodes with no dependents in their app.
        These are the "most current" version of an app's schema.
        Having more than one per app is technically an error, but one that
        gets handled further up, in the interactive command - it's usually the
        result of a VCS merge and needs some user input.

        If `app` is given, only that app's leaf nodes are returned. The result
        is sorted.
        """
        leaves = set()
        for node in self.nodes:
            if all(child[0] != node[0] for child in self.node_map[node].children) and (
                not app or app == node[0]
            ):
                leaves.add(node)
        return sorted(leaves)

    def ensure_not_cyclic(self):
        """
        Raise CircularDependencyError if the graph contains a cycle.

        The error message lists the keys on the detected cycle, each formatted
        as "app.name" and joined with ", ".
        """
        # Algo from GvR:
        # https://neopythonic.blogspot.com/2009/01/detecting-cycles-in-directed-graph.html
        todo = set(self.nodes)
        while todo:
            node = todo.pop()
            stack = [node]
            while stack:
                top = stack[-1]
                for child in self.node_map[top].children:
                    # Use child.key instead of child to speed up the frequent
                    # hashing.
                    node = child.key
                    if node in stack:
                        cycle = stack[stack.index(node) :]
                        raise CircularDependencyError(
                            ", ".join("{}.{}".format(*n) for n in cycle)
                        )
                    if node in todo:
                        stack.append(node)
                        todo.remove(node)
                        break
                else:
                    # Every child of `top` has been explored; backtrack.
                    stack.pop()

    def __str__(self):
        return "Graph: {} nodes, {} edges".format(*self._nodes_and_edges())

    def __repr__(self):
        nodes, edges = self._nodes_and_edges()
        return f"<{self.__class__.__name__}: nodes={nodes}, edges={edges}>"

    def _nodes_and_edges(self):
        """Return (node count, edge count). Each edge is counted once, via its child's parents."""
        return len(self.nodes), sum(
            len(node.parents) for node in self.node_map.values()
        )

    def _generate_plan(self, nodes, at_end):
        """
        Return the ordered, de-duplicated list of keys to apply so that every
        key in `nodes` is reached.

        When `at_end` is False, keys that are themselves in `nodes` are left
        out, giving the state just before those migrations.
        """
        # Sets make the membership checks O(1); the list preserves order.
        targets = set(nodes)
        plan = []
        planned = set()
        for node in nodes:
            for migration in self.forwards_plan(node):
                if migration in planned:
                    continue
                if at_end or migration not in targets:
                    plan.append(migration)
                    planned.add(migration)
        return plan

    def make_state(self, nodes=None, at_end=True, real_packages=None):
        """
        Given a migration node or nodes, return a complete ProjectState for it.
        If at_end is False, return the state before the migration has run.
        If nodes is not provided, return the overall most current project state.

        `nodes` may be a single key tuple or a list of keys. `real_packages` is
        forwarded to ProjectState, including when there are no nodes to apply.
        """
        if nodes is None:
            nodes = list(self.leaf_nodes())
        if not nodes:
            # Keep the caller's real_packages even when there is nothing to
            # apply, so that this path agrees with the non-empty path below.
            return ProjectState(real_packages=real_packages)
        # A single key like ("app", "0001") has a str as its first item; wrap it.
        if not isinstance(nodes[0], tuple):
            nodes = [nodes]
        plan = self._generate_plan(nodes, at_end)
        project_state = ProjectState(real_packages=real_packages)
        for node in plan:
            project_state = self.nodes[node].mutate_state(project_state, preserve=False)
        return project_state

    def __contains__(self, node):
        return node in self.nodes