1from __future__ import annotations
  2
  3from collections.abc import Callable
  4from typing import TYPE_CHECKING, Any
  5
  6from .base import Operation
  7
  8if TYPE_CHECKING:
  9    from plain.models.migrations.state import ProjectState
 10    from plain.models.postgres.schema import DatabaseSchemaEditor
 11
 12
 13class SeparateDatabaseAndState(Operation):
 14    """
 15    Take two lists of operations - ones that will be used for the database,
 16    and ones that will be used for the state change. This allows operations
 17    that don't support state change to have it applied, or have operations
 18    that affect the state or not the database, or so on.
 19    """
 20
 21    serialization_expand_args = ["database_operations", "state_operations"]
 22
 23    def __init__(
 24        self,
 25        database_operations: list[Operation] | None = None,
 26        state_operations: list[Operation] | None = None,
 27    ) -> None:
 28        self.database_operations = database_operations or []
 29        self.state_operations = state_operations or []
 30
 31    def deconstruct(self) -> tuple[str, tuple[Any, ...], dict[str, list[Operation]]]:
 32        kwargs: dict[str, list[Operation]] = {}
 33        if self.database_operations:
 34            kwargs["database_operations"] = self.database_operations
 35        if self.state_operations:
 36            kwargs["state_operations"] = self.state_operations
 37        return (self.__class__.__qualname__, (), kwargs)
 38
 39    def state_forwards(self, package_label: str, state: ProjectState) -> None:
 40        for state_operation in self.state_operations:
 41            state_operation.state_forwards(package_label, state)
 42
 43    def database_forwards(
 44        self,
 45        package_label: str,
 46        schema_editor: DatabaseSchemaEditor,
 47        from_state: ProjectState,
 48        to_state: ProjectState,
 49    ) -> None:
 50        # We calculate state separately in here since our state functions aren't useful
 51        for database_operation in self.database_operations:
 52            to_state = from_state.clone()
 53            database_operation.state_forwards(package_label, to_state)
 54            database_operation.database_forwards(
 55                package_label, schema_editor, from_state, to_state
 56            )
 57            from_state = to_state
 58
 59    def describe(self) -> str:
 60        return "Custom state/database change combination"
 61
 62
 63class RunSQL(Operation):
 64    """
 65    Run some raw SQL.
 66
 67    Also accept a list of operations that represent the state change effected
 68    by this SQL change, in case it's custom column/table creation/deletion.
 69    """
 70
 71    def __init__(
 72        self,
 73        sql: str
 74        | list[str | tuple[str, list[Any]]]
 75        | tuple[str | tuple[str, list[Any]], ...],
 76        *,
 77        state_operations: list[Operation] | None = None,
 78        elidable: bool = False,
 79    ) -> None:
 80        self.sql = sql
 81        self.state_operations = state_operations or []
 82        self.elidable = elidable
 83
 84    def deconstruct(self) -> tuple[str, tuple[Any, ...], dict[str, Any]]:
 85        kwargs: dict[str, Any] = {
 86            "sql": self.sql,
 87        }
 88        if self.state_operations:
 89            kwargs["state_operations"] = self.state_operations
 90        return (self.__class__.__qualname__, (), kwargs)
 91
 92    def state_forwards(self, package_label: str, state: ProjectState) -> None:
 93        for state_operation in self.state_operations:
 94            state_operation.state_forwards(package_label, state)
 95
 96    def database_forwards(
 97        self,
 98        package_label: str,
 99        schema_editor: DatabaseSchemaEditor,
100        from_state: ProjectState,
101        to_state: ProjectState,
102    ) -> None:
103        self._run_sql(schema_editor, self.sql)
104
105    def describe(self) -> str:
106        return "Raw SQL operation"
107
108    def _run_sql(
109        self,
110        schema_editor: DatabaseSchemaEditor,
111        sqls: str
112        | list[str | tuple[str, list[Any]]]
113        | tuple[str | tuple[str, list[Any]], ...],
114    ) -> None:
115        if isinstance(sqls, list | tuple):
116            for sql_item in sqls:
117                params: list[Any] | None = None
118                sql: str
119                if isinstance(sql_item, list | tuple):
120                    elements = len(sql_item)
121                    if elements == 2:
122                        sql, params = sql_item
123                    else:
124                        raise ValueError("Expected a 2-tuple but got %d" % elements)  # noqa: UP031
125                else:
126                    sql = sql_item
127                schema_editor.execute(sql, params=params)
128        else:
129            # PostgreSQL can handle multi-statement scripts in a single execute call
130            schema_editor.execute(sqls, params=None)
131
132
class RunPython(Operation):
    """
    Run Python code in a context suitable for doing versioned ORM operations.

    ``code`` is invoked as ``code(models_registry, schema_editor)``, where
    ``models_registry`` is ``from_state.models_registry``, the models as they
    exist at this point in the migration history.
    """

    # Arbitrary Python has no SQL rendering. Presumably consulted by tooling
    # that emits SQL for migrations; confirm against callers of this flag.
    reduces_to_sql = False

    def __init__(
        self,
        code: Callable[..., Any],
        *,
        atomic: bool | None = None,
        elidable: bool = False,
    ) -> None:
        """
        Raises:
            ValueError: if ``code`` is not callable.
        """
        # Validate before assigning any attributes.
        if not callable(code):
            raise ValueError("RunPython must be supplied with a callable")
        # Forwards code
        self.code = code
        self.atomic = atomic
        self.elidable = elidable

    def deconstruct(self) -> tuple[str, tuple[Any, ...], dict[str, Any]]:
        """
        Return (qualified class name, positional args, kwargs) to rebuild this op.

        ``atomic`` is emitted only when set explicitly. ``elidable`` is
        emitted only when True, so it survives a deconstruct/rebuild round
        trip; previously it was silently dropped.
        """
        kwargs: dict[str, Any] = {
            "code": self.code,
        }
        if self.atomic is not None:
            kwargs["atomic"] = self.atomic
        if self.elidable:
            kwargs["elidable"] = self.elidable
        return (self.__class__.__qualname__, (), kwargs)

    def state_forwards(self, package_label: str, state: ProjectState) -> None:
        # RunPython objects have no state effect. To add some, combine this
        # with SeparateDatabaseAndState.
        pass

    def database_forwards(
        self,
        package_label: str,
        schema_editor: DatabaseSchemaEditor,
        from_state: ProjectState,
        to_state: ProjectState,
    ) -> None:
        """Call ``self.code`` with the versioned models registry and the schema editor."""
        # RunPython has access to all models. Ensure that all models are
        # reloaded in case any are delayed.
        from_state.clear_delayed_models_cache()
        # Pass the versioned models registry explicitly rather than overriding
        # any global cache. Direct imports inside ``code`` would bypass such an
        # override anyway, so users are expected to use the registry argument.
        self.code(from_state.models_registry, schema_editor)

    def describe(self) -> str:
        return "Raw Python operation"