Plain is headed towards 1.0! Subscribe for development updates →

  1import builtins
  2import collections.abc
  3import datetime
  4import decimal
  5import enum
  6import functools
  7import math
  8import os
  9import pathlib
 10import re
 11import types
 12import uuid
 13
 14from plain import models
 15from plain.models.migrations.operations.base import Operation
 16from plain.models.migrations.utils import COMPILED_REGEX_TYPE, RegexObject
 17from plain.runtime import SettingsReference
 18from plain.utils.functional import LazyObject, Promise
 19
 20
 21class BaseSerializer:
 22    def __init__(self, value):
 23        self.value = value
 24
 25    def serialize(self):
 26        raise NotImplementedError(
 27            "Subclasses of BaseSerializer must implement the serialize() method."
 28        )
 29
 30
 31class BaseSequenceSerializer(BaseSerializer):
 32    def _format(self):
 33        raise NotImplementedError(
 34            "Subclasses of BaseSequenceSerializer must implement the _format() method."
 35        )
 36
 37    def serialize(self):
 38        imports = set()
 39        strings = []
 40        for item in self.value:
 41            item_string, item_imports = serializer_factory(item).serialize()
 42            imports.update(item_imports)
 43            strings.append(item_string)
 44        value = self._format()
 45        return value % (", ".join(strings)), imports
 46
 47
 48class BaseSimpleSerializer(BaseSerializer):
 49    def serialize(self):
 50        return repr(self.value), set()
 51
 52
 53class ChoicesSerializer(BaseSerializer):
 54    def serialize(self):
 55        return serializer_factory(self.value.value).serialize()
 56
 57
 58class DateTimeSerializer(BaseSerializer):
 59    """For datetime.*, except datetime.datetime."""
 60
 61    def serialize(self):
 62        return repr(self.value), {"import datetime"}
 63
 64
 65class DatetimeDatetimeSerializer(BaseSerializer):
 66    """For datetime.datetime."""
 67
 68    def serialize(self):
 69        if self.value.tzinfo is not None and self.value.tzinfo != datetime.UTC:
 70            self.value = self.value.astimezone(datetime.UTC)
 71        imports = ["import datetime"]
 72        return repr(self.value), set(imports)
 73
 74
 75class DecimalSerializer(BaseSerializer):
 76    def serialize(self):
 77        return repr(self.value), {"from decimal import Decimal"}
 78
 79
 80class DeconstructableSerializer(BaseSerializer):
 81    @staticmethod
 82    def serialize_deconstructed(path, args, kwargs):
 83        name, imports = DeconstructableSerializer._serialize_path(path)
 84        strings = []
 85        for arg in args:
 86            arg_string, arg_imports = serializer_factory(arg).serialize()
 87            strings.append(arg_string)
 88            imports.update(arg_imports)
 89        for kw, arg in sorted(kwargs.items()):
 90            arg_string, arg_imports = serializer_factory(arg).serialize()
 91            imports.update(arg_imports)
 92            strings.append(f"{kw}={arg_string}")
 93        return "{}({})".format(name, ", ".join(strings)), imports
 94
 95    @staticmethod
 96    def _serialize_path(path):
 97        module, name = path.rsplit(".", 1)
 98        if module == "plain.models":
 99            imports = {"from plain import models"}
100            name = f"models.{name}"
101        else:
102            imports = {f"import {module}"}
103            name = path
104        return name, imports
105
106    def serialize(self):
107        return self.serialize_deconstructed(*self.value.deconstruct())
108
109
class DictionarySerializer(BaseSerializer):
    """Serialize a dict as a ``{key: value}`` literal in sorted item order."""

    def serialize(self):
        imports = set()
        entries = []
        for key, val in sorted(self.value.items()):
            key_src, key_imports = serializer_factory(key).serialize()
            val_src, val_imports = serializer_factory(val).serialize()
            imports.update(key_imports)
            imports.update(val_imports)
            entries.append(f"{key_src}: {val_src}")
        return "{" + ", ".join(entries) + "}", imports
121
122
class EnumSerializer(BaseSerializer):
    """Serialize an enum member as a by-name lookup on its class.

    Flag values are rendered as the ``|`` of their individual members. A
    flag value with no members (for example a zero flag) is rendered as a
    call on the class with its raw value, because there are no names to
    join and an empty string would not be valid source.
    """

    def serialize(self):
        enum_class = self.value.__class__
        module = enum_class.__module__
        class_path = f"{module}.{enum_class.__qualname__}"
        imports = {f"import {module}"}
        if issubclass(enum_class, enum.Flag):
            # Iterating a Flag value yields its member flags (Python 3.11+).
            members = list(self.value)
            if not members:
                # e.g. Perm(0): joining no names would produce "", so rebuild
                # the value from the class instead.
                return f"{class_path}({self.value.value!r})", imports
        else:
            members = (self.value,)
        source = " | ".join(f"{class_path}[{item.name!r}]" for item in members)
        return source, imports
140
141
class FloatSerializer(BaseSimpleSerializer):
    """Serialize floats, spelling NaN and infinities as ``float("...")``."""

    def serialize(self):
        if not math.isfinite(self.value):
            # repr() gives bare nan/inf, which would not evaluate back to a
            # float in generated code.
            return f'float("{self.value}")', set()
        return super().serialize()
147
148
class FrozensetSerializer(BaseSequenceSerializer):
    """Render a frozenset as ``frozenset([...])``."""

    def _format(self):
        template = "frozenset([%s])"
        return template
152
153
class FunctionTypeSerializer(BaseSerializer):
    """Serialize a function or method as a dotted reference to it.

    Raises ValueError for lambdas, functions without a module, and functions
    whose qualified name is not importable (contains ``<``, e.g. ``<locals>``).
    """

    def serialize(self):
        func = self.value
        owner = getattr(func, "__self__", None)
        # Methods bound to a class object are referenced through that class.
        if owner and isinstance(owner, type):
            module = owner.__module__
            return f"{module}.{owner.__name__}.{func.__name__}", {f"import {module}"}
        # Further error checking
        if func.__name__ == "<lambda>":
            raise ValueError("Cannot serialize function: lambda")
        if func.__module__ is None:
            raise ValueError(f"Cannot serialize function {func!r}: No module")

        module_name = func.__module__
        if "<" in func.__qualname__:  # Qualname can include <locals>
            raise ValueError(
                f"Could not find function {func.__name__} in {module_name}.\n"
            )
        return f"{module_name}.{func.__qualname__}", {f"import {module_name}"}
180
181
class FunctoolsPartialSerializer(BaseSerializer):
    """Serialize ``functools.partial``/``partialmethod`` objects as calls."""

    def serialize(self):
        partial_obj = self.value
        imports = {"import functools"}
        sources = []
        # Serialize the wrapped callable and its bound arguments, in order,
        # collecting every import they need.
        for attr in ("func", "args", "keywords"):
            component = getattr(partial_obj, attr)
            source, component_imports = serializer_factory(component).serialize()
            sources.append(source)
            imports.update(component_imports)
        func_string, args_string, keywords_string = sources
        class_name = partial_obj.__class__.__name__
        return (
            f"functools.{class_name}({func_string}, *{args_string}, **{keywords_string})",
            imports,
        )
196
197
class IterableSerializer(BaseSerializer):
    """Serialize an arbitrary iterable as a tuple literal."""

    def serialize(self):
        imports = set()
        parts = []
        for item in self.value:
            text, item_imports = serializer_factory(item).serialize()
            imports.update(item_imports)
            parts.append(text)
        joined = ", ".join(parts)
        # A single element needs a trailing comma to remain a tuple; an empty
        # iterable stays "()" because "(,)" is invalid Python syntax.
        if len(parts) == 1:
            return f"({joined},)", imports
        return f"({joined})", imports
210
211
class ModelFieldSerializer(DeconstructableSerializer):
    """Serialize a model field from its ``deconstruct()`` output.

    The field's ``deconstruct()`` yields four items; the first (the
    attribute name) is discarded and the rest describe the constructor call.
    """

    def serialize(self):
        _attr_name, path, args, kwargs = self.value.deconstruct()
        return self.serialize_deconstructed(path, args, kwargs)
216
217
class ModelManagerSerializer(DeconstructableSerializer):
    """Serialize a manager as ``<QuerySet>.as_manager()`` or a constructor call."""

    def serialize(self):
        as_manager, manager_path, qs_path, args, kwargs = self.value.deconstruct()
        if not as_manager:
            return self.serialize_deconstructed(manager_path, args, kwargs)
        qs_name, imports = self._serialize_path(qs_path)
        return f"{qs_name}.as_manager()", imports
226
227
class OperationSerializer(BaseSerializer):
    """Serialize a nested migration operation using ``OperationWriter``."""

    def serialize(self):
        # Imported at call time; presumably to avoid an import cycle with the
        # writer module — confirm before hoisting.
        from plain.models.migrations.writer import OperationWriter

        writer = OperationWriter(self.value, indentation=0)
        string, imports = writer.serialize()
        # Nested operation, trailing comma is handled in upper OperationWriter._write()
        return string.rstrip(","), imports
235
236
class PathLikeSerializer(BaseSerializer):
    """Serialize an ``os.PathLike`` object as the repr of its filesystem path."""

    def serialize(self):
        # os.fspath() gives the underlying str/bytes path, whose repr needs no
        # imports. Return an empty *set* (``{}`` would be a dict) to honour the
        # (source, set_of_imports) contract every other serializer follows.
        return repr(os.fspath(self.value)), set()
240
241
class PathSerializer(BaseSerializer):
    """Serialize ``pathlib`` paths, always emitting a pure-path class."""

    def serialize(self):
        # Convert concrete paths to pure paths to avoid issues with migrations
        # generated on one platform being used on a different platform.
        if isinstance(self.value, pathlib.Path):
            return f"pathlib.Pure{self.value!r}", {"import pathlib"}
        return f"pathlib.{self.value!r}", {"import pathlib"}
248
249
class RegexSerializer(BaseSerializer):
    """Serialize a compiled regex as ``re.compile(pattern[, flags])``.

    Only flags beyond the ones ``re.compile`` adds implicitly for the
    pattern's type are emitted.
    """

    def serialize(self):
        pattern = self.value.pattern
        regex_pattern, pattern_imports = serializer_factory(pattern).serialize()
        # Turn off default implicit flags (e.g. re.U) because regexes with the
        # same implicit and explicit flags aren't equal. The implicit flags
        # depend on the pattern type (str gets re.U, bytes gets none), so take
        # them from an empty pattern of the same type, and clear only those
        # bits: XOR would *add* re.U where it was absent (bytes patterns, or
        # str patterns compiled with re.ASCII), emitting a re.compile() call
        # that raises ValueError when the migration is loaded.
        default_flags = re.compile(pattern[:0]).flags
        flags = self.value.flags & ~default_flags
        regex_flags, flag_imports = serializer_factory(flags).serialize()
        imports = {"import re", *pattern_imports, *flag_imports}
        args = [regex_pattern]
        if flags:
            args.append(regex_flags)
        return "re.compile({})".format(", ".join(args)), imports
264
265
class SequenceSerializer(BaseSequenceSerializer):
    """Render a list as a ``[...]`` literal."""

    def _format(self):
        list_template = "[%s]"
        return list_template
269
270
class SetSerializer(BaseSequenceSerializer):
    """Render a set as a ``{...}`` literal, or ``set()`` when it is empty."""

    def _format(self):
        # "{}" would be an empty dict literal, so empty sets use set().
        if not self.value:
            return "set(%s)"
        return "{%s}"
276
277
class SettingsReferenceSerializer(BaseSerializer):
    """Serialize a settings reference as ``settings.<setting_name>``."""

    def serialize(self):
        setting = self.value.setting_name
        required = {"from plain.runtime import settings"}
        return f"settings.{setting}", required
283
284
class TupleSerializer(BaseSequenceSerializer):
    """Render a tuple literal, adding the trailing comma a 1-tuple needs."""

    def _format(self):
        # "(x)" is just a parenthesized expression, so a one-element tuple
        # needs "(x,)"; the empty tuple stays "()" since "(,)" is invalid.
        if len(self.value) == 1:
            return "(%s,)"
        return "(%s)"
290
291
class TypeSerializer(BaseSerializer):
    """Serialize a class object as a reference to it.

    ``models.Model`` and ``types.NoneType`` get fixed spellings; builtins are
    referenced by bare name; anything else by ``module.qualname``.
    """

    def serialize(self):
        if self.value is models.Model:
            return "models.Model", {"from plain import models"}
        if self.value is types.NoneType:
            return "types.NoneType", {"import types"}
        if hasattr(self.value, "__module__"):
            module = self.value.__module__
            if module == builtins.__name__:
                return self.value.__name__, set()
            return f"{module}.{self.value.__qualname__}", {f"import {module}"}
        # NOTE(review): without __module__ this returns None, which callers
        # that unpack (source, imports) cannot handle.
        return None
307
308
class UUIDSerializer(BaseSerializer):
    """Serialize a UUID as ``uuid.UUID('...')``."""

    def serialize(self):
        return "uuid." + repr(self.value), {"import uuid"}
312
313
class Serializer:
    """Registry mapping value types to the serializer class that handles them.

    ``serializer_factory()`` scans ``_registry`` in insertion order and uses
    the first entry whose key ``isinstance`` matches, so specific types must
    precede broader types they subclass.
    """

    _registry = {
        # Some of these are order-dependent.
        frozenset: FrozensetSerializer,
        list: SequenceSerializer,
        set: SetSerializer,
        tuple: TupleSerializer,
        dict: DictionarySerializer,
        # models.Choices presumably subclasses enum.Enum, hence listed first;
        # enum.Enum precedes int/str below so mixed-in enums (e.g. IntEnum)
        # are treated as enums.
        models.Choices: ChoicesSerializer,
        enum.Enum: EnumSerializer,
        # datetime.datetime subclasses datetime.date, so it must come first.
        datetime.datetime: DatetimeDatetimeSerializer,
        (datetime.date, datetime.timedelta, datetime.time): DateTimeSerializer,
        SettingsReference: SettingsReferenceSerializer,
        float: FloatSerializer,
        (bool, int, types.NoneType, bytes, str, range): BaseSimpleSerializer,
        decimal.Decimal: DecimalSerializer,
        (functools.partial, functools.partialmethod): FunctoolsPartialSerializer,
        (types.FunctionType, types.BuiltinFunctionType, types.MethodType): (
            FunctionTypeSerializer
        ),
        # The generic iterable fallback sits after the concrete containers.
        collections.abc.Iterable: IterableSerializer,
        (COMPILED_REGEX_TYPE, RegexObject): RegexSerializer,
        uuid.UUID: UUIDSerializer,
        pathlib.PurePath: PathSerializer,
        os.PathLike: PathLikeSerializer,
    }

    @classmethod
    def register(cls, type_, serializer):
        """Map ``type_`` to ``serializer``, a BaseSerializer subclass.

        A new key is appended to the end of the scan order; an existing key
        keeps its position and gets the new serializer.

        Raises:
            ValueError: if ``serializer`` does not inherit BaseSerializer.
        """
        if issubclass(serializer, BaseSerializer):
            cls._registry[type_] = serializer
            return
        raise ValueError(
            f"'{serializer.__name__}' must inherit from 'BaseSerializer'."
        )
350
351
def serializer_factory(value):
    """Return a serializer instance suited to ``value``.

    Lazy wrappers are resolved first. Model fields, managers, migration
    operations, classes and objects with a ``deconstruct()`` method are
    matched before the ``Serializer._registry`` scan, where the first
    registered type that ``value`` is an instance of wins.

    Raises:
        ValueError: if no serializer matches ``value``.
    """
    if isinstance(value, Promise):
        # Promise proxies are forced to a concrete str.
        value = str(value)
    elif isinstance(value, LazyObject):
        # The unwrapped value is returned as the first item of the arguments
        # tuple.
        value = value.__reduce__()[1][0]

    if isinstance(value, models.Field):
        chosen = ModelFieldSerializer
    elif isinstance(value, models.manager.BaseManager):
        chosen = ModelManagerSerializer
    elif isinstance(value, Operation):
        chosen = OperationSerializer
    elif isinstance(value, type):
        chosen = TypeSerializer
    elif hasattr(value, "deconstruct"):
        # Anything that knows how to deconstruct itself.
        chosen = DeconstructableSerializer
    else:
        chosen = next(
            (
                serializer_cls
                for registered_type, serializer_cls in Serializer._registry.items()
                if isinstance(value, registered_type)
            ),
            None,
        )
        if chosen is None:
            raise ValueError(
                f"Cannot serialize: {value!r}\nThere are some values Plain cannot serialize into "
                "migration files."
            )
    return chosen(value)