Plain is headed towards 1.0! Subscribe for development updates →

  1from __future__ import annotations
  2
  3import os
  4import re
  5from importlib import import_module
  6from typing import Any
  7
  8from plain.models import migrations
  9from plain.models.migrations.loader import MigrationLoader
 10from plain.models.migrations.migration import SettingsTuple
 11from plain.models.migrations.serializer import serializer_factory
 12from plain.packages import packages_registry
 13from plain.runtime import __version__
 14from plain.utils.inspect import get_func_args
 15from plain.utils.module_loading import module_dir
 16from plain.utils.timezone import now
 17
 18
 19class OperationWriter:
 20    def __init__(self, operation: Any, indentation: int = 2) -> None:
 21        self.operation = operation
 22        self.buff: list[str] = []
 23        self.indentation = indentation
 24
 25    def serialize(self) -> tuple[str, set[str]]:
 26        def _write(_arg_name: str, _arg_value: Any) -> None:
 27            if _arg_name in self.operation.serialization_expand_args and isinstance(
 28                _arg_value, list | tuple | dict
 29            ):
 30                if isinstance(_arg_value, dict):
 31                    self.feed(f"{_arg_name}={{")
 32                    self.indent()
 33                    for key, value in _arg_value.items():
 34                        key_string, key_imports = MigrationWriter.serialize(key)
 35                        arg_string, arg_imports = MigrationWriter.serialize(value)
 36                        args = arg_string.splitlines()
 37                        if len(args) > 1:
 38                            self.feed(f"{key_string}: {args[0]}")
 39                            for arg in args[1:-1]:
 40                                self.feed(arg)
 41                            self.feed(f"{args[-1]},")
 42                        else:
 43                            self.feed(f"{key_string}: {arg_string},")
 44                        imports.update(key_imports)
 45                        imports.update(arg_imports)
 46                    self.unindent()
 47                    self.feed("},")
 48                else:
 49                    self.feed(f"{_arg_name}=[")
 50                    self.indent()
 51                    for item in _arg_value:
 52                        arg_string, arg_imports = MigrationWriter.serialize(item)
 53                        args = arg_string.splitlines()
 54                        if len(args) > 1:
 55                            for arg in args[:-1]:
 56                                self.feed(arg)
 57                            self.feed(f"{args[-1]},")
 58                        else:
 59                            self.feed(f"{arg_string},")
 60                        imports.update(arg_imports)
 61                    self.unindent()
 62                    self.feed("],")
 63            else:
 64                arg_string, arg_imports = MigrationWriter.serialize(_arg_value)
 65                args = arg_string.splitlines()
 66                if len(args) > 1:
 67                    self.feed(f"{_arg_name}={args[0]}")
 68                    for arg in args[1:-1]:
 69                        self.feed(arg)
 70                    self.feed(f"{args[-1]},")
 71                else:
 72                    self.feed(f"{_arg_name}={arg_string},")
 73                imports.update(arg_imports)
 74
 75        imports = set()
 76        name, args, kwargs = self.operation.deconstruct()
 77        operation_args = get_func_args(self.operation.__init__)
 78
 79        # See if this operation is in plain.models.migrations. If it is,
 80        # We can just use the fact we already have that imported,
 81        # otherwise, we need to add an import for the operation class.
 82        if getattr(migrations, name, None) == self.operation.__class__:
 83            self.feed(f"migrations.{name}(")
 84        else:
 85            imports.add(f"import {self.operation.__class__.__module__}")
 86            self.feed(f"{self.operation.__class__.__module__}.{name}(")
 87
 88        self.indent()
 89
 90        for i, arg in enumerate(args):
 91            arg_value = arg
 92            arg_name = operation_args[i]
 93            _write(arg_name, arg_value)
 94
 95        i = len(args)
 96        # Only iterate over remaining arguments
 97        for arg_name in operation_args[i:]:
 98            if arg_name in kwargs:  # Don't sort to maintain signature order
 99                arg_value = kwargs[arg_name]
100                _write(arg_name, arg_value)
101
102        self.unindent()
103        self.feed("),")
104        return self.render(), imports
105
106    def indent(self) -> None:
107        self.indentation += 1
108
109    def unindent(self) -> None:
110        self.indentation -= 1
111
112    def feed(self, line: str) -> None:
113        self.buff.append(" " * (self.indentation * 4) + line)
114
115    def render(self) -> str:
116        return "\n".join(self.buff)
117
118
class MigrationWriter:
    """
    Take a Migration instance and produce the contents of the migration
    file from it.
    """

    def __init__(self, migration: Any, include_header: bool = True) -> None:
        self.migration = migration
        self.include_header = include_header
        # Set to True by as_string() when an operation's import points into
        # another migration module.
        self.needs_manual_porting = False

    @staticmethod
    def _pop_migration_imports(imports: set[str]) -> set[str]:
        """
        Remove "import <module>" lines that refer to migration modules (a
        dotted path whose last component starts with digits) from ``imports``
        in place and return the module paths they referred to.
        """
        matched = {
            line for line in imports if re.match(r"^import (.*)\.\d+[^\s]*$", line)
        }
        imports.difference_update(matched)
        # Strip only the leading keyword: splitting on every "import" would
        # truncate module paths that themselves contain that word.
        return {line.removeprefix("import").strip() for line in matched}

    def as_string(self) -> str:
        """Return a string of the file contents."""
        items = {
            "replaces_str": "",
            "initial_str": "",
        }

        imports = set()

        # Deconstruct operations
        operations = []
        for operation in self.migration.operations:
            operation_string, operation_imports = OperationWriter(operation).serialize()
            imports.update(operation_imports)
            operations.append(operation_string)
        items["operations"] = "\n".join(operations) + "\n" if operations else ""

        # Format dependencies and write out settings dependencies right
        dependencies = []
        for dependency in self.migration.dependencies:
            if isinstance(dependency, SettingsTuple):
                dependencies.append(
                    f"        migrations.settings_dependency(settings.{dependency[1]}),"
                )
                imports.add("from plain.runtime import settings")
            else:
                dependencies.append(f"        {self.serialize(dependency)[0]},")
        items["dependencies"] = "\n".join(dependencies) + "\n" if dependencies else ""

        # Format imports nicely, swapping imports of functions from migration files
        # for comments
        migration_imports = self._pop_migration_imports(imports)
        if migration_imports:
            self.needs_manual_porting = True

        imports.add("from plain.models import migrations")

        # Sort imports by the package / module to be imported (the part after
        # "from" in "from ... import ..." or after "import" in "import ...").
        # First group the "import" statements, then "from ... import ...".
        sorted_imports = sorted(
            imports, key=lambda i: (i.split()[0] == "from", i.split()[1])
        )
        items["imports"] = "\n".join(sorted_imports) + "\n" if imports else ""
        if migration_imports:
            items["imports"] += (
                "\n\n# Functions from the following migrations need manual "
                "copying.\n# Move them and any dependencies into this file, "
                "then update the\n# RunPython operations to refer to the local "
                "versions:\n# {}"
            ).format("\n# ".join(sorted(migration_imports)))
        # If there's a replaces, make a string for it
        if self.migration.replaces:
            items["replaces_str"] = (
                f"\n    replaces = {self.serialize(self.migration.replaces)[0]}\n"
            )
        # Hinting that goes into comment
        if self.include_header:
            items["migration_header"] = MIGRATION_HEADER_TEMPLATE % {
                "version": __version__,
                "timestamp": now().strftime("%Y-%m-%d %H:%M"),
            }
        else:
            items["migration_header"] = ""

        if self.migration.initial:
            items["initial_str"] = "\n    initial = True\n"

        return MIGRATION_TEMPLATE % items

    @property
    def basedir(self) -> str:
        """
        Return the directory the migration file belongs in.

        If the migrations package cannot be imported and is not a direct
        submodule of the package, the missing package directories (and their
        ``__init__.py`` files) are created on disk as a side effect.

        Raises ValueError if migrations are disabled for the package or if no
        importable parent package can be found to create the directories in.
        """
        migrations_package_name, _ = MigrationLoader.migrations_module(
            self.migration.package_label
        )

        if migrations_package_name is None:
            raise ValueError(
                f"Plain can't create migrations for app '{self.migration.package_label}' because "
                "migrations have been disabled via the MIGRATION_MODULES "
                "setting."
            )

        # See if we can import the migrations module directly
        try:
            migrations_module = import_module(migrations_package_name)
        except ImportError:
            pass
        else:
            try:
                return module_dir(migrations_module)
            except ValueError:
                pass

        # Alright, see if it's a direct submodule of the app
        package_config = packages_registry.get_package_config(
            self.migration.package_label
        )
        (
            maybe_package_name,
            _,
            migrations_package_basename,
        ) = migrations_package_name.rpartition(".")
        if package_config.name == maybe_package_name:
            return os.path.join(package_config.path, migrations_package_basename)

        # In case of using MIGRATION_MODULES setting and the custom package
        # doesn't exist, create one, starting from an existing package
        existing_dirs, missing_dirs = migrations_package_name.split("."), []
        while existing_dirs:
            # Peel off the innermost component and retry with the shorter
            # dotted path until something importable with a directory is found.
            missing_dirs.insert(0, existing_dirs.pop(-1))
            try:
                base_module = import_module(".".join(existing_dirs))
            except (ImportError, ValueError):
                continue
            else:
                try:
                    base_dir = module_dir(base_module)
                except ValueError:
                    continue
                else:
                    break
        else:
            raise ValueError(
                "Could not locate an appropriate location to create "
                f"migrations package {migrations_package_name}. Make sure the toplevel "
                "package exists and can be imported."
            )

        final_dir = os.path.join(base_dir, *missing_dirs)
        os.makedirs(final_dir, exist_ok=True)
        for missing_dir in missing_dirs:
            base_dir = os.path.join(base_dir, missing_dir)
            # Append mode creates the file when it is absent but leaves the
            # contents alone if the package already exists on disk and merely
            # failed to import.
            with open(os.path.join(base_dir, "__init__.py"), "a"):
                pass

        return final_dir

    @property
    def filename(self) -> str:
        """Return the migration's file name (its name plus ``.py``)."""
        return f"{self.migration.name}.py"

    @property
    def path(self) -> str:
        """Return the full path of the migration file inside ``basedir``."""
        return os.path.join(self.basedir, self.filename)

    @classmethod
    def serialize(cls, value: Any) -> tuple[str, set[str]]:
        """Serialize ``value`` via ``serializer_factory`` and return its result."""
        return serializer_factory(value).serialize()
282
283
# Comment header placed at the top of a generated migration file. Filled via
# %-formatting with the "version" and "timestamp" keys.
MIGRATION_HEADER_TEMPLATE = "# Generated by Plain %(version)s on %(timestamp)s\n\n"


# Skeleton of a generated migration module, filled via %-formatting. The
# "dependencies" and "operations" values are expected to carry their own
# trailing newline (or be empty), so no newline follows those placeholders.
MIGRATION_TEMPLATE = (
    "%(migration_header)s%(imports)s\n"
    "\n"
    "class Migration(migrations.Migration):\n"
    "%(replaces_str)s%(initial_str)s\n"
    "    dependencies = [\n"
    "%(dependencies)s"
    "    ]\n"
    "\n"
    "    operations = [\n"
    "%(operations)s"
    "    ]\n"
)