1from __future__ import annotations
2
3from collections.abc import Callable
4from typing import TYPE_CHECKING, Any
5
6from .base import Operation
7
8if TYPE_CHECKING:
9 from plain.models.backends.base.schema import BaseDatabaseSchemaEditor
10 from plain.models.migrations.state import ProjectState
11
12
class SeparateDatabaseAndState(Operation):
    """
    Pair two independent lists of operations: one applied to the database,
    one applied to the in-memory project state.

    Useful when an operation should change the schema without touching the
    recorded state (or vice versa), or when an operation lacks state support.
    """

    serialization_expand_args = ["database_operations", "state_operations"]

    def __init__(
        self,
        database_operations: list[Operation] | None = None,
        state_operations: list[Operation] | None = None,
    ) -> None:
        # Falsy inputs (None or an empty list) both normalize to a fresh list.
        self.database_operations = database_operations if database_operations else []
        self.state_operations = state_operations if state_operations else []

    def deconstruct(self) -> tuple[str, tuple[Any, ...], dict[str, list[Operation]]]:
        """Return (name, args, kwargs), omitting either list when it is empty."""
        candidates = (
            ("database_operations", self.database_operations),
            ("state_operations", self.state_operations),
        )
        kwargs: dict[str, list[Operation]] = {
            key: ops for key, ops in candidates if ops
        }
        return (type(self).__qualname__, (), kwargs)

    def state_forwards(self, package_label: str, state: ProjectState) -> None:
        """Apply only the state operations, in order, to ``state``."""
        for op in self.state_operations:
            op.state_forwards(package_label, state)

    def database_forwards(
        self,
        package_label: str,
        schema_editor: BaseDatabaseSchemaEditor,
        from_state: ProjectState,
        to_state: ProjectState,
    ) -> None:
        """
        Run the database operations in sequence.

        The supplied ``to_state`` is ignored: each database operation gets its
        own before/after pair, derived by cloning the running state and
        applying that operation's own ``state_forwards``.
        """
        current = from_state
        for op in self.database_operations:
            following = current.clone()
            op.state_forwards(package_label, following)
            op.database_forwards(package_label, schema_editor, current, following)
            current = following

    def describe(self) -> str:
        return "Custom state/database change combination"
61
62
class RunSQL(Operation):
    """
    Run some raw SQL.

    Also accept a list of operations that represent the state change effected
    by this SQL change, in case it's custom column/table creation/deletion.

    ``sql`` may be:
      * a single string, which the backend splits into statements via
        ``connection.ops.prepare_sql_script``; or
      * a list/tuple whose items are either plain SQL strings or
        ``(sql, params)`` 2-tuples executed with those params.
    """

    def __init__(
        self,
        sql: str
        | list[str | tuple[str, list[Any]]]
        | tuple[str | tuple[str, list[Any]], ...],
        *,
        state_operations: list[Operation] | None = None,
        elidable: bool = False,
    ) -> None:
        self.sql = sql
        self.state_operations = state_operations or []
        self.elidable = elidable

    def deconstruct(self) -> tuple[str, tuple[Any, ...], dict[str, Any]]:
        """
        Return ``(name, args, kwargs)`` from which this operation can be rebuilt.

        Optional arguments are only emitted when they differ from their
        constructor defaults. ``elidable`` must be included. Otherwise an
        operation rebuilt from this output (e.g. when written into a migration
        file) silently reverts to ``elidable=False``.
        """
        kwargs: dict[str, Any] = {
            "sql": self.sql,
        }
        if self.state_operations:
            kwargs["state_operations"] = self.state_operations
        if self.elidable:
            kwargs["elidable"] = self.elidable
        return (self.__class__.__qualname__, (), kwargs)

    def state_forwards(self, package_label: str, state: ProjectState) -> None:
        # The SQL itself has no state effect; only the declared
        # state_operations are applied.
        for state_operation in self.state_operations:
            state_operation.state_forwards(package_label, state)

    def database_forwards(
        self,
        package_label: str,
        schema_editor: BaseDatabaseSchemaEditor,
        from_state: ProjectState,
        to_state: ProjectState,
    ) -> None:
        self._run_sql(schema_editor, self.sql)

    def describe(self) -> str:
        return "Raw SQL operation"

    def _run_sql(
        self,
        schema_editor: BaseDatabaseSchemaEditor,
        sqls: str
        | list[str | tuple[str, list[Any]]]
        | tuple[str | tuple[str, list[Any]], ...],
    ) -> None:
        """
        Execute ``sqls`` through ``schema_editor``.

        Raises:
            ValueError: if a list/tuple item is itself a list/tuple whose
                length is not 2 (i.e. not an ``(sql, params)`` pair).
        """
        if isinstance(sqls, list | tuple):
            # Each item is executed as exactly one statement; no splitting.
            for sql_item in sqls:
                params: list[Any] | None = None
                sql: str
                if isinstance(sql_item, list | tuple):
                    elements = len(sql_item)
                    if elements == 2:
                        sql, params = sql_item
                    else:
                        raise ValueError("Expected a 2-tuple but got %d" % elements)  # noqa: UP031
                else:
                    sql = sql_item
                schema_editor.execute(sql, params=params)
        else:
            # sqls is a str in this branch: let the backend split the script
            # into individual statements, executed without params.
            statements = schema_editor.connection.ops.prepare_sql_script(sqls)
            for statement in statements:
                schema_editor.execute(statement, params=None)
133
134
class RunPython(Operation):
    """
    Run Python code in a context suitable for doing versioned ORM operations.

    ``code`` is invoked as ``code(models_registry, schema_editor)`` when the
    operation is applied forwards.
    """

    # NOTE(review): presumably tells SQL-rendering tooling that this operation
    # cannot be expressed as SQL. Confirm against base.Operation.
    reduces_to_sql = False

    def __init__(
        self,
        code: Callable[..., Any],
        *,
        atomic: bool | None = None,
        elidable: bool = False,
    ) -> None:
        # Validate before storing anything so a bad argument fails immediately.
        if not callable(code):
            raise ValueError("RunPython must be supplied with a callable")
        # Forwards code
        self.code = code
        self.atomic = atomic
        self.elidable = elidable

    def deconstruct(self) -> tuple[str, tuple[Any, ...], dict[str, Any]]:
        """
        Return ``(name, args, kwargs)`` from which this operation can be rebuilt.

        ``atomic`` and ``elidable`` are emitted only when they differ from their
        constructor defaults. Omitting ``elidable`` would make a rebuilt
        operation silently revert to ``elidable=False``.
        """
        kwargs: dict[str, Any] = {
            "code": self.code,
        }
        if self.atomic is not None:
            kwargs["atomic"] = self.atomic
        if self.elidable:
            kwargs["elidable"] = self.elidable
        return (self.__class__.__qualname__, (), kwargs)

    def state_forwards(self, package_label: str, state: ProjectState) -> None:
        # RunPython objects have no state effect. To add some, combine this
        # with SeparateDatabaseAndState.
        pass

    def database_forwards(
        self,
        package_label: str,
        schema_editor: BaseDatabaseSchemaEditor,
        from_state: ProjectState,
        to_state: ProjectState,
    ) -> None:
        # RunPython has access to all models. Ensure that all models are
        # reloaded in case any are delayed.
        from_state.clear_delayed_models_cache()
        # The code receives the versioned models registry of from_state rather
        # than the live global registry, so it sees models as they were at this
        # point in the migration history. Direct model imports inside the code
        # would bypass this, which is a documented usage rule rather than
        # something enforced here.
        self.code(from_state.models_registry, schema_editor)

    def describe(self) -> str:
        return "Raw Python operation"