Plain is headed towards 1.0! Subscribe for development updates →

  1from __future__ import annotations
  2
  3from collections.abc import Callable
  4from typing import TYPE_CHECKING, Any
  5
  6from .base import Operation
  7
  8if TYPE_CHECKING:
  9    from plain.models.backends.base.schema import BaseDatabaseSchemaEditor
 10    from plain.models.migrations.state import ProjectState
 11
 12
 13class SeparateDatabaseAndState(Operation):
 14    """
 15    Take two lists of operations - ones that will be used for the database,
 16    and ones that will be used for the state change. This allows operations
 17    that don't support state change to have it applied, or have operations
 18    that affect the state or not the database, or so on.
 19    """
 20
 21    serialization_expand_args = ["database_operations", "state_operations"]
 22
 23    def __init__(
 24        self,
 25        database_operations: list[Operation] | None = None,
 26        state_operations: list[Operation] | None = None,
 27    ) -> None:
 28        self.database_operations = database_operations or []
 29        self.state_operations = state_operations or []
 30
 31    def deconstruct(self) -> tuple[str, tuple[Any, ...], dict[str, list[Operation]]]:
 32        kwargs: dict[str, list[Operation]] = {}
 33        if self.database_operations:
 34            kwargs["database_operations"] = self.database_operations
 35        if self.state_operations:
 36            kwargs["state_operations"] = self.state_operations
 37        return (self.__class__.__qualname__, (), kwargs)
 38
 39    def state_forwards(self, package_label: str, state: ProjectState) -> None:
 40        for state_operation in self.state_operations:
 41            state_operation.state_forwards(package_label, state)
 42
 43    def database_forwards(
 44        self,
 45        package_label: str,
 46        schema_editor: BaseDatabaseSchemaEditor,
 47        from_state: ProjectState,
 48        to_state: ProjectState,
 49    ) -> None:
 50        # We calculate state separately in here since our state functions aren't useful
 51        for database_operation in self.database_operations:
 52            to_state = from_state.clone()
 53            database_operation.state_forwards(package_label, to_state)
 54            database_operation.database_forwards(
 55                package_label, schema_editor, from_state, to_state
 56            )
 57            from_state = to_state
 58
 59    def describe(self) -> str:
 60        return "Custom state/database change combination"
 61
 62
 63class RunSQL(Operation):
 64    """
 65    Run some raw SQL.
 66
 67    Also accept a list of operations that represent the state change effected
 68    by this SQL change, in case it's custom column/table creation/deletion.
 69    """
 70
 71    def __init__(
 72        self,
 73        sql: str
 74        | list[str | tuple[str, list[Any]]]
 75        | tuple[str | tuple[str, list[Any]], ...],
 76        *,
 77        state_operations: list[Operation] | None = None,
 78        elidable: bool = False,
 79    ) -> None:
 80        self.sql = sql
 81        self.state_operations = state_operations or []
 82        self.elidable = elidable
 83
 84    def deconstruct(self) -> tuple[str, tuple[Any, ...], dict[str, Any]]:
 85        kwargs: dict[str, Any] = {
 86            "sql": self.sql,
 87        }
 88        if self.state_operations:
 89            kwargs["state_operations"] = self.state_operations
 90        return (self.__class__.__qualname__, (), kwargs)
 91
 92    def state_forwards(self, package_label: str, state: ProjectState) -> None:
 93        for state_operation in self.state_operations:
 94            state_operation.state_forwards(package_label, state)
 95
 96    def database_forwards(
 97        self,
 98        package_label: str,
 99        schema_editor: BaseDatabaseSchemaEditor,
100        from_state: ProjectState,
101        to_state: ProjectState,
102    ) -> None:
103        self._run_sql(schema_editor, self.sql)
104
105    def describe(self) -> str:
106        return "Raw SQL operation"
107
108    def _run_sql(
109        self,
110        schema_editor: BaseDatabaseSchemaEditor,
111        sqls: str
112        | list[str | tuple[str, list[Any]]]
113        | tuple[str | tuple[str, list[Any]], ...],
114    ) -> None:
115        if isinstance(sqls, list | tuple):
116            for sql_item in sqls:
117                params: list[Any] | None = None
118                sql: str
119                if isinstance(sql_item, list | tuple):
120                    elements = len(sql_item)
121                    if elements == 2:
122                        sql, params = sql_item
123                    else:
124                        raise ValueError("Expected a 2-tuple but got %d" % elements)  # noqa: UP031
125                else:
126                    sql = sql_item
127                schema_editor.execute(sql, params=params)
128        else:
129            # sqls is a str in this branch
130            statements = schema_editor.connection.ops.prepare_sql_script(sqls)
131            for statement in statements:
132                schema_editor.execute(statement, params=None)
133
134
class RunPython(Operation):
    """
    Call a user-supplied Python callable as a migration step, handing it the
    models registry of the pre-operation state and the schema editor.
    """

    reduces_to_sql = False

    def __init__(
        self,
        code: Callable[..., Any],
        *,
        atomic: bool | None = None,
        elidable: bool = False,
    ) -> None:
        """
        Store the forwards callable together with the ``atomic`` and
        ``elidable`` settings.

        Raises:
            ValueError: if ``code`` is not callable.
        """
        if not callable(code):
            raise ValueError("RunPython must be supplied with a callable")
        # Forwards code
        self.code = code
        self.atomic = atomic
        self.elidable = elidable

    def deconstruct(self) -> tuple[str, tuple[Any, ...], dict[str, Any]]:
        """
        Return ``(qualified class name, (), kwargs)``.

        ``kwargs`` always carries ``code``; ``atomic`` is added only when it
        is not ``None``.
        """
        atomic = self.atomic
        kwargs: dict[str, Any] = (
            {"code": self.code}
            if atomic is None
            else {"code": self.code, "atomic": atomic}
        )
        return (self.__class__.__qualname__, (), kwargs)

    def state_forwards(self, package_label: str, state: Any) -> None:
        """
        Do nothing: RunPython has no state effect of its own. To pair it with
        a state change, combine it with SeparateDatabaseAndState.
        """
        return None

    def database_forwards(
        self,
        package_label: str,
        schema_editor: BaseDatabaseSchemaEditor,
        from_state: ProjectState,
        to_state: ProjectState,
    ) -> None:
        """
        Invoke the stored callable as ``code(models_registry, schema_editor)``.

        ``to_state`` is not used.
        """
        # The callable may touch any model, so clear the delayed-models cache
        # first to make sure any delayed models are reloaded.
        from_state.clear_delayed_models_cache()
        # The callable is given the registry built from the migration state
        # rather than relying on a global override; callers are expected to
        # use this registry instead of importing models directly.
        registry = from_state.models_registry
        self.code(registry, schema_editor)

    def describe(self) -> str:
        """Return a fixed human-readable description of this operation."""
        return "Raw Python operation"