Plain is headed towards 1.0! Subscribe for development updates →

  1import os
  2import re
  3from importlib import import_module
  4
  5from plain.models import migrations
  6from plain.models.migrations.loader import MigrationLoader
  7from plain.models.migrations.serializer import Serializer, serializer_factory
  8from plain.packages import packages
  9from plain.runtime import __version__
 10from plain.utils.inspect import get_func_args
 11from plain.utils.module_loading import module_dir
 12from plain.utils.timezone import now
 13
 14
 15class OperationWriter:
 16    def __init__(self, operation, indentation=2):
 17        self.operation = operation
 18        self.buff = []
 19        self.indentation = indentation
 20
 21    def serialize(self):
 22        def _write(_arg_name, _arg_value):
 23            if _arg_name in self.operation.serialization_expand_args and isinstance(
 24                _arg_value, list | tuple | dict
 25            ):
 26                if isinstance(_arg_value, dict):
 27                    self.feed("%s={" % _arg_name)
 28                    self.indent()
 29                    for key, value in _arg_value.items():
 30                        key_string, key_imports = MigrationWriter.serialize(key)
 31                        arg_string, arg_imports = MigrationWriter.serialize(value)
 32                        args = arg_string.splitlines()
 33                        if len(args) > 1:
 34                            self.feed(f"{key_string}: {args[0]}")
 35                            for arg in args[1:-1]:
 36                                self.feed(arg)
 37                            self.feed("%s," % args[-1])
 38                        else:
 39                            self.feed(f"{key_string}: {arg_string},")
 40                        imports.update(key_imports)
 41                        imports.update(arg_imports)
 42                    self.unindent()
 43                    self.feed("},")
 44                else:
 45                    self.feed("%s=[" % _arg_name)
 46                    self.indent()
 47                    for item in _arg_value:
 48                        arg_string, arg_imports = MigrationWriter.serialize(item)
 49                        args = arg_string.splitlines()
 50                        if len(args) > 1:
 51                            for arg in args[:-1]:
 52                                self.feed(arg)
 53                            self.feed("%s," % args[-1])
 54                        else:
 55                            self.feed("%s," % arg_string)
 56                        imports.update(arg_imports)
 57                    self.unindent()
 58                    self.feed("],")
 59            else:
 60                arg_string, arg_imports = MigrationWriter.serialize(_arg_value)
 61                args = arg_string.splitlines()
 62                if len(args) > 1:
 63                    self.feed(f"{_arg_name}={args[0]}")
 64                    for arg in args[1:-1]:
 65                        self.feed(arg)
 66                    self.feed("%s," % args[-1])
 67                else:
 68                    self.feed(f"{_arg_name}={arg_string},")
 69                imports.update(arg_imports)
 70
 71        imports = set()
 72        name, args, kwargs = self.operation.deconstruct()
 73        operation_args = get_func_args(self.operation.__init__)
 74
 75        # See if this operation is in plain.models.migrations. If it is,
 76        # We can just use the fact we already have that imported,
 77        # otherwise, we need to add an import for the operation class.
 78        if getattr(migrations, name, None) == self.operation.__class__:
 79            self.feed("migrations.%s(" % name)
 80        else:
 81            imports.add("import %s" % (self.operation.__class__.__module__))
 82            self.feed(f"{self.operation.__class__.__module__}.{name}(")
 83
 84        self.indent()
 85
 86        for i, arg in enumerate(args):
 87            arg_value = arg
 88            arg_name = operation_args[i]
 89            _write(arg_name, arg_value)
 90
 91        i = len(args)
 92        # Only iterate over remaining arguments
 93        for arg_name in operation_args[i:]:
 94            if arg_name in kwargs:  # Don't sort to maintain signature order
 95                arg_value = kwargs[arg_name]
 96                _write(arg_name, arg_value)
 97
 98        self.unindent()
 99        self.feed("),")
100        return self.render(), imports
101
102    def indent(self):
103        self.indentation += 1
104
105    def unindent(self):
106        self.indentation -= 1
107
108    def feed(self, line):
109        self.buff.append(" " * (self.indentation * 4) + line)
110
111    def render(self):
112        return "\n".join(self.buff)
113
114
class MigrationWriter:
    """
    Take a Migration instance and produce the contents of the migration
    file from it.
    """

    def __init__(self, migration, include_header=True):
        self.migration = migration
        self.include_header = include_header
        # Set to True by as_string() when an operation imports from another
        # migration module, which the generated file cannot do directly.
        self.needs_manual_porting = False

    @staticmethod
    def _split_migration_imports(imports):
        """
        Separate imports that point into migration modules from the rest.

        A line of the form ``import <path>.<digits...>`` (the last dotted
        component starts with a digit) is treated as an import of a
        migration module. Return ``(regular_imports, migration_modules)``:
        the set of import lines to keep, and the set of dotted module paths
        of the migration modules that were split off.
        """
        regular_imports = set()
        migration_modules = set()
        for line in imports:
            # Capture the module path with the regex rather than splitting
            # the line on the substring "import": a path that itself
            # contains "import" (e.g. "importer.migrations.0001_initial")
            # would otherwise be cut apart.
            match = re.match(r"^import (.*\.\d+[^\s]*)$", line)
            if match:
                migration_modules.add(match.group(1).strip())
            else:
                regular_imports.add(line)
        return regular_imports, migration_modules

    def as_string(self):
        """
        Return a string of the file contents.

        Side effect: sets ``self.needs_manual_porting`` to True when any
        operation required an import from another migration module.
        """
        items = {
            "replaces_str": "",
            "initial_str": "",
        }

        imports = set()

        # Deconstruct operations
        operations = []
        for operation in self.migration.operations:
            operation_string, operation_imports = OperationWriter(operation).serialize()
            imports.update(operation_imports)
            operations.append(operation_string)
        items["operations"] = "\n".join(operations) + "\n" if operations else ""

        # Format dependencies, writing swappable dependencies as a
        # reference to the setting rather than a literal tuple.
        dependencies = []
        for dependency in self.migration.dependencies:
            if dependency[0] == "__setting__":
                dependencies.append(
                    "        migrations.swappable_dependency(settings.%s),"
                    % dependency[1]
                )
                imports.add("from plain.runtime import settings")
            else:
                dependencies.append("        %s," % self.serialize(dependency)[0])
        items["dependencies"] = "\n".join(dependencies) + "\n" if dependencies else ""

        # Format imports nicely, swapping imports of functions from migration
        # files for comments
        imports, migration_imports = self._split_migration_imports(imports)
        if migration_imports:
            self.needs_manual_porting = True

        imports.add("from plain.models import migrations")

        # Sort imports by the package / module to be imported (the part after
        # "from" in "from ... import ..." or after "import" in "import ...").
        # First group the "import" statements, then "from ... import ...".
        sorted_imports = sorted(
            imports, key=lambda i: (i.split()[0] == "from", i.split()[1])
        )
        items["imports"] = "\n".join(sorted_imports) + "\n" if imports else ""
        if migration_imports:
            items["imports"] += (
                "\n\n# Functions from the following migrations need manual "
                "copying.\n# Move them and any dependencies into this file, "
                "then update the\n# RunPython operations to refer to the local "
                "versions:\n# %s"
            ) % "\n# ".join(sorted(migration_imports))
        # If there's a replaces, make a string for it
        if self.migration.replaces:
            items["replaces_str"] = (
                "\n    replaces = %s\n" % self.serialize(self.migration.replaces)[0]
            )
        # Hinting that goes into comment
        if self.include_header:
            items["migration_header"] = MIGRATION_HEADER_TEMPLATE % {
                "version": __version__,
                "timestamp": now().strftime("%Y-%m-%d %H:%M"),
            }
        else:
            items["migration_header"] = ""

        if self.migration.initial:
            items["initial_str"] = "\n    initial = True\n"

        return MIGRATION_TEMPLATE % items

    @property
    def basedir(self):
        """
        Return the directory the migration file should be written to.

        If the migrations package cannot be imported and is not a direct
        submodule of the owning package, the missing package directories
        are created on disk (each with an ``__init__.py``) below the nearest
        importable parent package.

        Raises:
            ValueError: if migrations are disabled for the package, or no
                importable parent package can be found to create the
                migrations package in.
        """
        migrations_package_name, _ = MigrationLoader.migrations_module(
            self.migration.package_label
        )

        if migrations_package_name is None:
            raise ValueError(
                "Plain can't create migrations for app '%s' because "
                "migrations have been disabled via the MIGRATION_MODULES "
                "setting." % self.migration.package_label
            )

        # See if we can import the migrations module directly
        try:
            migrations_module = import_module(migrations_package_name)
        except ImportError:
            pass
        else:
            try:
                return module_dir(migrations_module)
            except ValueError:
                pass

        # Alright, see if it's a direct submodule of the app
        package_config = packages.get_package_config(self.migration.package_label)
        (
            maybe_package_name,
            _,
            migrations_package_basename,
        ) = migrations_package_name.rpartition(".")
        if package_config.name == maybe_package_name:
            return os.path.join(package_config.path, migrations_package_basename)

        # In case of using MIGRATION_MODULES setting and the custom package
        # doesn't exist, create one, starting from an existing package
        existing_dirs, missing_dirs = migrations_package_name.split("."), []
        while existing_dirs:
            missing_dirs.insert(0, existing_dirs.pop(-1))
            try:
                base_module = import_module(".".join(existing_dirs))
            except (ImportError, ValueError):
                continue
            else:
                try:
                    base_dir = module_dir(base_module)
                except ValueError:
                    continue
                else:
                    break
        else:
            raise ValueError(
                "Could not locate an appropriate location to create "
                "migrations package %s. Make sure the toplevel "
                "package exists and can be imported." % migrations_package_name
            )

        final_dir = os.path.join(base_dir, *missing_dirs)
        os.makedirs(final_dir, exist_ok=True)
        for missing_dir in missing_dirs:
            base_dir = os.path.join(base_dir, missing_dir)
            # Append mode creates the file when it is missing but never
            # truncates one that is already there: a "missing" package may
            # exist on disk and merely have failed to import.
            with open(os.path.join(base_dir, "__init__.py"), "a"):
                pass

        return final_dir

    @property
    def filename(self):
        """Return the migration's file name: its name plus ``.py``."""
        return "%s.py" % self.migration.name

    @property
    def path(self):
        """Return the full path of the migration file (may create dirs, see basedir)."""
        return os.path.join(self.basedir, self.filename)

    @classmethod
    def serialize(cls, value):
        """Serialize ``value`` via ``serializer_factory`` and return its result."""
        return serializer_factory(value).serialize()

    @classmethod
    def register_serializer(cls, type_, serializer):
        """Register ``serializer`` for ``type_`` on the shared Serializer registry."""
        Serializer.register(type_, serializer)

    @classmethod
    def unregister_serializer(cls, type_):
        """Remove the serializer registered for ``type_``."""
        Serializer.unregister(type_)
285
286
# Comment header placed at the top of a generated migration file.
# MigrationWriter.as_string() fills it with %-formatting, supplying the
# "version" and "timestamp" keys. The backslash after the opening quotes keeps
# the string from starting with a newline.
MIGRATION_HEADER_TEMPLATE = """\
# Generated by Plain %(version)s on %(timestamp)s

"""
291
292
# Skeleton of a generated migration file, filled with %-formatting by
# MigrationWriter.as_string(). Expected keys: migration_header, imports,
# replaces_str, initial_str, dependencies, operations.
# The trailing backslashes after %(dependencies)s and %(operations)s join the
# line with the next one, so no extra newline is emitted there: as_string()
# already terminates a non-empty block with a newline, and an empty block
# leaves the closing bracket directly after the opening one.
MIGRATION_TEMPLATE = """\
%(migration_header)s%(imports)s

class Migration(migrations.Migration):
%(replaces_str)s%(initial_str)s
    dependencies = [
%(dependencies)s\
    ]

    operations = [
%(operations)s\
    ]
"""