Plain is headed towards 1.0! Subscribe for development updates →

  1from __future__ import annotations
  2
  3from collections.abc import Callable
  4from typing import TYPE_CHECKING, Any, cast
  5
  6from .base import Operation
  7
  8if TYPE_CHECKING:
  9    from plain.models.backends.base.schema import BaseDatabaseSchemaEditor
 10    from plain.models.migrations.state import ProjectState
 11
 12
 13class SeparateDatabaseAndState(Operation):
 14    """
 15    Take two lists of operations - ones that will be used for the database,
 16    and ones that will be used for the state change. This allows operations
 17    that don't support state change to have it applied, or have operations
 18    that affect the state or not the database, or so on.
 19    """
 20
 21    serialization_expand_args = ["database_operations", "state_operations"]
 22
 23    def __init__(
 24        self,
 25        database_operations: list[Operation] | None = None,
 26        state_operations: list[Operation] | None = None,
 27    ) -> None:
 28        self.database_operations = database_operations or []
 29        self.state_operations = state_operations or []
 30
 31    def deconstruct(self) -> tuple[str, list[Any], dict[str, list[Operation]]]:
 32        kwargs: dict[str, list[Operation]] = {}
 33        if self.database_operations:
 34            kwargs["database_operations"] = self.database_operations
 35        if self.state_operations:
 36            kwargs["state_operations"] = self.state_operations
 37        return (self.__class__.__qualname__, [], kwargs)
 38
 39    def state_forwards(self, package_label: str, state: ProjectState) -> None:
 40        for state_operation in self.state_operations:
 41            state_operation.state_forwards(package_label, state)
 42
 43    def database_forwards(
 44        self,
 45        package_label: str,
 46        schema_editor: BaseDatabaseSchemaEditor,
 47        from_state: ProjectState,
 48        to_state: ProjectState,
 49    ) -> None:
 50        # We calculate state separately in here since our state functions aren't useful
 51        for database_operation in self.database_operations:
 52            to_state = from_state.clone()
 53            database_operation.state_forwards(package_label, to_state)
 54            database_operation.database_forwards(
 55                package_label, schema_editor, from_state, to_state
 56            )
 57            from_state = to_state
 58
 59    def describe(self) -> str:
 60        return "Custom state/database change combination"
 61
 62
 63class RunSQL(Operation):
 64    """
 65    Run some raw SQL.
 66
 67    Also accept a list of operations that represent the state change effected
 68    by this SQL change, in case it's custom column/table creation/deletion.
 69    """
 70
 71    def __init__(
 72        self,
 73        sql: str
 74        | list[str | tuple[str, list[Any]]]
 75        | tuple[str | tuple[str, list[Any]], ...],
 76        *,
 77        state_operations: list[Operation] | None = None,
 78        elidable: bool = False,
 79    ) -> None:
 80        self.sql = sql
 81        self.state_operations = state_operations or []
 82        self.elidable = elidable
 83
 84    def deconstruct(self) -> tuple[str, list[Any], dict[str, Any]]:
 85        kwargs: dict[str, Any] = {
 86            "sql": self.sql,
 87        }
 88        if self.state_operations:
 89            kwargs["state_operations"] = self.state_operations
 90        return (self.__class__.__qualname__, [], kwargs)
 91
 92    def state_forwards(self, package_label: str, state: ProjectState) -> None:
 93        for state_operation in self.state_operations:
 94            state_operation.state_forwards(package_label, state)
 95
 96    def database_forwards(
 97        self,
 98        package_label: str,
 99        schema_editor: BaseDatabaseSchemaEditor,
100        from_state: ProjectState,
101        to_state: ProjectState,
102    ) -> None:
103        self._run_sql(schema_editor, self.sql)
104
105    def describe(self) -> str:
106        return "Raw SQL operation"
107
108    def _run_sql(
109        self,
110        schema_editor: BaseDatabaseSchemaEditor,
111        sqls: str
112        | list[str | tuple[str, list[Any]]]
113        | tuple[str | tuple[str, list[Any]], ...],
114    ) -> None:
115        if isinstance(sqls, list | tuple):
116            for sql_item in sqls:
117                params: list[Any] | None = None
118                sql: str
119                if isinstance(sql_item, list | tuple):
120                    elements = len(sql_item)
121                    if elements == 2:
122                        sql, params = sql_item  # type: ignore[misc]
123                    else:
124                        raise ValueError("Expected a 2-tuple but got %d" % elements)  # noqa: UP031
125                else:
126                    sql = sql_item
127                schema_editor.execute(sql, params=params)
128        else:
129            # sqls is a str in this branch
130            statements = schema_editor.connection.ops.prepare_sql_script(
131                cast(str, sqls)
132            )
133            for statement in statements:
134                schema_editor.execute(statement, params=None)
135
136
class RunPython(Operation):
    """
    Run Python code in a context suitable for doing versioned ORM operations.

    ``code`` is called as ``code(models_registry, schema_editor)``, where
    ``models_registry`` is the versioned model registry of the migration's
    ``from_state``. The operation has no effect on project state. Combine it
    with ``SeparateDatabaseAndState`` if a state change is needed.
    """

    # Presumably tells SQL-rendering tooling that this operation has no SQL
    # form — TODO confirm against the Operation base class.
    reduces_to_sql = False

    def __init__(
        self,
        code: Callable[..., Any],
        *,
        atomic: bool | None = None,
        elidable: bool = False,
    ) -> None:
        """
        Raises:
            ValueError: if ``code`` is not callable.
        """
        # None means "not specified"; only an explicit value is serialized.
        self.atomic = atomic
        # Forwards code
        if not callable(code):
            raise ValueError("RunPython must be supplied with a callable")
        self.code = code
        self.elidable = elidable

    def deconstruct(self) -> tuple[str, list[Any], dict[str, Any]]:
        """
        Return ``(qualname, args, kwargs)`` for re-serializing this operation.

        Optional keyword arguments are included only when they differ from
        their defaults.
        """
        kwargs: dict[str, Any] = {
            "code": self.code,
        }
        if self.atomic is not None:
            kwargs["atomic"] = self.atomic
        if self.elidable:
            # Without this the flag would be lost whenever the operation is
            # deconstructed and written back out.
            kwargs["elidable"] = self.elidable
        return (self.__class__.__qualname__, [], kwargs)

    def state_forwards(self, package_label: str, state: ProjectState) -> None:
        # RunPython objects have no state effect. To add some, combine this
        # with SeparateDatabaseAndState.
        pass

    def database_forwards(
        self,
        package_label: str,
        schema_editor: BaseDatabaseSchemaEditor,
        from_state: ProjectState,
        to_state: ProjectState,
    ) -> None:
        """Invoke ``self.code`` with ``from_state``'s model registry and the schema editor."""
        # RunPython has access to all models. Ensure that all models are
        # reloaded in case any are delayed.
        from_state.clear_delayed_models_cache()
        # We now execute the Python code in a context that contains a 'models'
        # object, representing the versioned models as an app registry.
        # We could try to override the global cache, but then people will still
        # use direct imports, so we go with a documentation approach instead.
        self.code(from_state.models_registry, schema_editor)

    def describe(self) -> str:
        return "Raw Python operation"
187    def describe(self) -> str:
188        return "Raw Python operation"