1from __future__ import annotations
2
3from collections.abc import Callable
4from typing import TYPE_CHECKING, Any, cast
5
6from .base import Operation
7
8if TYPE_CHECKING:
9 from plain.models.backends.base.schema import BaseDatabaseSchemaEditor
10 from plain.models.migrations.state import ProjectState
11
12
class SeparateDatabaseAndState(Operation):
    """
    Pair a list of database-only operations with a list of state-only ones.

    ``state_operations`` drive the project-state change, while
    ``database_operations`` drive the schema change. This lets an operation
    that cannot describe its own state change have one applied, or lets a
    change affect the state without touching the database (or vice versa).
    """

    serialization_expand_args = ["database_operations", "state_operations"]

    def __init__(
        self,
        database_operations: list[Operation] | None = None,
        state_operations: list[Operation] | None = None,
    ) -> None:
        # Any falsy value (None or an empty list) normalizes to a fresh list.
        self.database_operations = (
            database_operations if database_operations else []
        )
        self.state_operations = state_operations if state_operations else []

    def deconstruct(self) -> tuple[str, list[Any], dict[str, list[Operation]]]:
        """Return ``(qualname, args, kwargs)``, omitting empty operation lists."""
        candidates = (
            ("database_operations", self.database_operations),
            ("state_operations", self.state_operations),
        )
        kwargs: dict[str, list[Operation]] = {
            key: operations for key, operations in candidates if operations
        }
        return (self.__class__.__qualname__, [], kwargs)

    def state_forwards(self, package_label: str, state: ProjectState) -> None:
        """Forward ``state`` through each of ``state_operations`` in order."""
        for operation in self.state_operations:
            operation.state_forwards(package_label, state)

    def database_forwards(
        self,
        package_label: str,
        schema_editor: BaseDatabaseSchemaEditor,
        from_state: ProjectState,
        to_state: ProjectState,
    ) -> None:
        """Run each of ``database_operations`` in order via ``schema_editor``.

        The incoming ``to_state`` is not used. It reflects
        ``state_operations``, which need not match what the database
        operations expect. Instead, each database operation gets its own
        target state: a clone of the running state, forwarded by that
        operation's ``state_forwards``. That target then becomes the source
        state for the next operation.
        """
        current = from_state
        for operation in self.database_operations:
            following = current.clone()
            operation.state_forwards(package_label, following)
            operation.database_forwards(
                package_label, schema_editor, current, following
            )
            current = following

    def describe(self) -> str:
        """Return a short human-readable description of this operation."""
        return "Custom state/database change combination"
61
62
class RunSQL(Operation):
    """
    Run some raw SQL.

    Also accept a list of operations that represent the state change effected
    by this SQL change, in case it's custom column/table creation/deletion.

    ``sql`` is either a single string, or a list/tuple whose items are each a
    SQL string or a ``(sql, params)`` pair. A single string is split into
    statements by the connection's ``prepare_sql_script``. Sequence items are
    executed one by one.
    """

    def __init__(
        self,
        sql: str
        | list[str | tuple[str, list[Any]]]
        | tuple[str | tuple[str, list[Any]], ...],
        *,
        state_operations: list[Operation] | None = None,
        elidable: bool = False,
    ) -> None:
        self.sql = sql
        # Operations applied to project state; the SQL itself never changes state.
        self.state_operations = state_operations or []
        self.elidable = elidable

    def deconstruct(self) -> tuple[str, list[Any], dict[str, Any]]:
        """Return ``(qualname, args, kwargs)`` to re-create this operation.

        Optional arguments are emitted only when they differ from their
        defaults. ``elidable`` is included so that re-creating the operation
        from this output preserves the flag; previously it was silently reset
        to ``False``.
        """
        kwargs: dict[str, Any] = {
            "sql": self.sql,
        }
        if self.state_operations:
            kwargs["state_operations"] = self.state_operations
        if self.elidable:
            kwargs["elidable"] = self.elidable
        return (self.__class__.__qualname__, [], kwargs)

    def state_forwards(self, package_label: str, state: ProjectState) -> None:
        """Forward ``state`` through each of ``state_operations`` in order."""
        for state_operation in self.state_operations:
            state_operation.state_forwards(package_label, state)

    def database_forwards(
        self,
        package_label: str,
        schema_editor: BaseDatabaseSchemaEditor,
        from_state: ProjectState,
        to_state: ProjectState,
    ) -> None:
        """Execute ``self.sql`` through ``schema_editor``; states are unused."""
        self._run_sql(schema_editor, self.sql)

    def describe(self) -> str:
        """Return a short human-readable description of this operation."""
        return "Raw SQL operation"

    def _run_sql(
        self,
        schema_editor: BaseDatabaseSchemaEditor,
        sqls: str
        | list[str | tuple[str, list[Any]]]
        | tuple[str | tuple[str, list[Any]], ...],
    ) -> None:
        """Execute ``sqls`` through ``schema_editor``.

        A list/tuple is executed item by item. Each item is either a SQL
        string, run with ``params=None``, or a two-element ``(sql, params)``
        sequence.

        A plain string is split into statements by
        ``schema_editor.connection.ops.prepare_sql_script``, and each
        statement is run with ``params=None``.

        Raises:
            ValueError: if a sequence item does not have exactly two elements.
        """
        if isinstance(sqls, list | tuple):
            for sql_item in sqls:
                params: list[Any] | None = None
                sql: str
                if isinstance(sql_item, list | tuple):
                    elements = len(sql_item)
                    if elements != 2:
                        raise ValueError(f"Expected a 2-tuple but got {elements}")
                    sql, params = sql_item  # type: ignore[misc]
                else:
                    sql = sql_item
                schema_editor.execute(sql, params=params)
        else:
            # sqls is a str in this branch; let the backend split the script.
            statements = schema_editor.connection.ops.prepare_sql_script(
                cast(str, sqls)
            )
            for statement in statements:
                schema_editor.execute(statement, params=None)
135
136
class RunPython(Operation):
    """
    Run Python code in a context suitable for doing versioned ORM operations.

    ``code`` is invoked as ``code(models_registry, schema_editor)``, where
    ``models_registry`` is taken from the migration's ``from_state``. The
    operation itself makes no change to project state.
    """

    # NOTE(review): presumably tells SQL-rendering tooling that this operation
    # has no SQL form — confirm against the base Operation / its consumers.
    reduces_to_sql = False

    def __init__(
        self,
        code: Callable[..., Any],
        *,
        atomic: bool | None = None,
        elidable: bool = False,
    ) -> None:
        # Stored as-is; None means "not specified" and is omitted by deconstruct().
        self.atomic = atomic
        # Forwards code
        if not callable(code):
            raise ValueError("RunPython must be supplied with a callable")
        self.code = code
        self.elidable = elidable

    def deconstruct(self) -> tuple[str, list[Any], dict[str, Any]]:
        """Return ``(qualname, args, kwargs)`` to re-create this operation.

        ``atomic`` is emitted only when explicitly set. ``elidable`` is emitted
        only when true, so that re-creating the operation from this output
        preserves the flag; previously it was silently reset to ``False``.
        """
        kwargs: dict[str, Any] = {
            "code": self.code,
        }
        if self.atomic is not None:
            kwargs["atomic"] = self.atomic
        if self.elidable:
            kwargs["elidable"] = self.elidable
        return (self.__class__.__qualname__, [], kwargs)

    def state_forwards(self, package_label: str, state: ProjectState) -> None:
        # RunPython objects have no state effect. To add some, combine this
        # with SeparateDatabaseAndState.
        pass

    def database_forwards(
        self,
        package_label: str,
        schema_editor: BaseDatabaseSchemaEditor,
        from_state: ProjectState,
        to_state: ProjectState,
    ) -> None:
        """Call ``self.code`` with ``from_state``'s models registry and the editor."""
        # RunPython has access to all models. Ensure that all models are
        # reloaded in case any are delayed.
        from_state.clear_delayed_models_cache()
        # We now execute the Python code in a context that contains a 'models'
        # object, representing the versioned models as an app registry.
        # We could try to override the global cache, but then people will still
        # use direct imports, so we go with a documentation approach instead.
        self.code(from_state.models_registry, schema_editor)

    def describe(self) -> str:
        """Return a short human-readable description of this operation."""
        return "Raw Python operation"