Plain is headed towards 1.0! Subscribe for development updates →

  1import builtins
  2import collections.abc
  3import datetime
  4import decimal
  5import enum
  6import functools
  7import math
  8import os
  9import pathlib
 10import re
 11import types
 12import uuid
 13
 14from plain import models
 15from plain.models.migrations.operations.base import Operation
 16from plain.models.migrations.utils import COMPILED_REGEX_TYPE, RegexObject
 17from plain.runtime.user_settings import SettingsReference
 18from plain.utils.functional import LazyObject, Promise
 19
 20
 21class BaseSerializer:
 22    def __init__(self, value):
 23        self.value = value
 24
 25    def serialize(self):
 26        raise NotImplementedError(
 27            "Subclasses of BaseSerializer must implement the serialize() method."
 28        )
 29
 30
 31class BaseSequenceSerializer(BaseSerializer):
 32    def _format(self):
 33        raise NotImplementedError(
 34            "Subclasses of BaseSequenceSerializer must implement the _format() method."
 35        )
 36
 37    def serialize(self):
 38        imports = set()
 39        strings = []
 40        for item in self.value:
 41            item_string, item_imports = serializer_factory(item).serialize()
 42            imports.update(item_imports)
 43            strings.append(item_string)
 44        value = self._format()
 45        return value % (", ".join(strings)), imports
 46
 47
 48class BaseSimpleSerializer(BaseSerializer):
 49    def serialize(self):
 50        return repr(self.value), set()
 51
 52
 53class ChoicesSerializer(BaseSerializer):
 54    def serialize(self):
 55        return serializer_factory(self.value.value).serialize()
 56
 57
 58class DateTimeSerializer(BaseSerializer):
 59    """For datetime.*, except datetime.datetime."""
 60
 61    def serialize(self):
 62        return repr(self.value), {"import datetime"}
 63
 64
 65class DatetimeDatetimeSerializer(BaseSerializer):
 66    """For datetime.datetime."""
 67
 68    def serialize(self):
 69        if self.value.tzinfo is not None and self.value.tzinfo != datetime.timezone.utc:
 70            self.value = self.value.astimezone(datetime.timezone.utc)
 71        imports = ["import datetime"]
 72        return repr(self.value), set(imports)
 73
 74
 75class DecimalSerializer(BaseSerializer):
 76    def serialize(self):
 77        return repr(self.value), {"from decimal import Decimal"}
 78
 79
 80class DeconstructableSerializer(BaseSerializer):
 81    @staticmethod
 82    def serialize_deconstructed(path, args, kwargs):
 83        name, imports = DeconstructableSerializer._serialize_path(path)
 84        strings = []
 85        for arg in args:
 86            arg_string, arg_imports = serializer_factory(arg).serialize()
 87            strings.append(arg_string)
 88            imports.update(arg_imports)
 89        for kw, arg in sorted(kwargs.items()):
 90            arg_string, arg_imports = serializer_factory(arg).serialize()
 91            imports.update(arg_imports)
 92            strings.append(f"{kw}={arg_string}")
 93        return "{}({})".format(name, ", ".join(strings)), imports
 94
 95    @staticmethod
 96    def _serialize_path(path):
 97        module, name = path.rsplit(".", 1)
 98        if module == "plain.models":
 99            imports = {"from plain import models"}
100            name = "models.%s" % name
101        else:
102            imports = {"import %s" % module}
103            name = path
104        return name, imports
105
106    def serialize(self):
107        return self.serialize_deconstructed(*self.value.deconstruct())
108
109
class DictionarySerializer(BaseSerializer):
    """Serialize a dict as a ``{key: value, ...}`` literal.

    Items are emitted in sorted order so that equal dicts always produce the
    same migration text. If the items cannot be ordered, the dict's own
    insertion order is used instead of failing. This happens with mixed
    ``int`` and ``str`` keys, where ``sorted`` raises ``TypeError``.

    Returns:
        A ``(source, imports)`` tuple, where ``imports`` is the union of the
        imports required by every key and value.
    """

    def serialize(self):
        try:
            items = sorted(self.value.items())
        except TypeError:
            # Unorderable keys: insertion order is still deterministic for a
            # given dict, so use it rather than aborting serialization.
            items = list(self.value.items())
        imports = set()
        strings = []
        for k, v in items:
            k_string, k_imports = serializer_factory(k).serialize()
            v_string, v_imports = serializer_factory(v).serialize()
            imports.update(k_imports)
            imports.update(v_imports)
            strings.append(f"{k_string}: {v_string}")
        return "{%s}" % ", ".join(strings), imports
121
122
class EnumSerializer(BaseSerializer):
    """Serialize an enum member as ``module.Class['NAME']``.

    Flag values are written as the ``|`` of their members' name lookups. A
    flag value that iterates to no members (e.g. a zero value such as
    ``Perm(0)``) would otherwise render as an empty string, which is invalid
    source. Such values are emitted as a constructor call,
    ``module.Class(value)``, instead.

    Returns:
        A ``(source, imports)`` tuple importing the enum's defining module.
    """

    def serialize(self):
        enum_class = self.value.__class__
        module = enum_class.__module__
        qualname = enum_class.__qualname__
        imports = {"import %s" % module}
        if issubclass(enum_class, enum.Flag):
            members = list(self.value)
            if not members:
                # Nothing to join: rebuild the value from its raw integer.
                return f"{module}.{qualname}({self.value.value!r})", imports
        else:
            members = (self.value,)
        return (
            " | ".join(f"{module}.{qualname}[{item.name!r}]" for item in members),
            imports,
        )
140
141
class FloatSerializer(BaseSimpleSerializer):
    """Serialize floats; NaN and infinities become ``float("nan")`` etc.

    ``repr`` of a non-finite float (``nan``/``inf``) is not a valid literal,
    so those values are spelled through the ``float()`` constructor.
    """

    def serialize(self):
        if math.isfinite(self.value):
            return super().serialize()
        return f'float("{self.value}")', set()
147
148
class FrozensetSerializer(BaseSequenceSerializer):
    """Serialize a frozenset as ``frozenset([...])``.

    Iteration order of a frozenset depends on element hashes, and for ``str``
    elements those differ between interpreter runs (hash randomization).
    Elements are therefore sorted by ``repr`` so that the same frozenset
    always yields the same migration text.
    """

    def __init__(self, value):
        super().__init__(sorted(value, key=repr))

    def _format(self):
        return "frozenset([%s])"
152
153
class FunctionTypeSerializer(BaseSerializer):
    """Serialize functions, builtins and classmethods by dotted import path.

    Raises:
        ValueError: for lambdas, for functions whose ``__module__`` is None,
            and for functions or classes that are not reachable by qualified
            name (their ``__qualname__`` contains ``<``, e.g. ``<locals>``).
    """

    def serialize(self):
        owner = getattr(self.value, "__self__", None)
        if isinstance(owner, type):
            # Bound classmethod: reference it through the class it is bound to.
            # __qualname__ (not __name__) keeps the enclosing classes, so a
            # method of a nested class resolves as module.Outer.Inner.method.
            module = owner.__module__
            if "<" in owner.__qualname__:
                # e.g. a class defined inside a function; not importable.
                raise ValueError(
                    f"Could not find function {self.value.__name__} in {module}.\n"
                )
            return f"{module}.{owner.__qualname__}.{self.value.__name__}", {
                "import %s" % module
            }
        # Further error checking
        if self.value.__name__ == "<lambda>":
            raise ValueError("Cannot serialize function: lambda")
        if self.value.__module__ is None:
            raise ValueError("Cannot serialize function %r: No module" % self.value)

        module_name = self.value.__module__

        if "<" not in self.value.__qualname__:  # Qualname can include <locals>
            return f"{module_name}.{self.value.__qualname__}", {
                "import %s" % self.value.__module__
            }

        raise ValueError(
            f"Could not find function {self.value.__name__} in {module_name}.\n"
        )
180
181
class FunctoolsPartialSerializer(BaseSerializer):
    """Serialize ``functools.partial``/``partialmethod`` objects.

    The output has the form ``functools.<type>(func, *args, **keywords)``.
    The wrapped callable, positional args and keyword dict are each
    serialized recursively.
    """

    def serialize(self):
        wrapped = self.value
        func_src, func_needs = serializer_factory(wrapped.func).serialize()
        args_src, args_needs = serializer_factory(wrapped.args).serialize()
        kw_src, kw_needs = serializer_factory(wrapped.keywords).serialize()
        imports = {"import functools"}.union(func_needs, args_needs, kw_needs)
        kind = wrapped.__class__.__name__
        return f"functools.{kind}({func_src}, *{args_src}, **{kw_src})", imports
201
202
class IterableSerializer(BaseSerializer):
    """Serialize an arbitrary iterable as a tuple literal.

    A single element gets a trailing comma (``(x,)``). An empty iterable
    becomes ``()``, because ``(,)`` is a syntax error.
    """

    def serialize(self):
        imports = set()
        rendered = []
        for element in self.value:
            text, needed = serializer_factory(element).serialize()
            imports.update(needed)
            rendered.append(text)
        body = ", ".join(rendered)
        if len(rendered) == 1:
            return f"({body},)", imports
        return f"({body})", imports
215
216
class ModelFieldSerializer(DeconstructableSerializer):
    """Serialize a model field from its 4-item ``deconstruct()`` result.

    The leading element (the field's attribute name) is not part of the
    constructor call and is discarded.
    """

    def serialize(self):
        _, path, args, kwargs = self.value.deconstruct()
        return self.serialize_deconstructed(path, args, kwargs)
221
222
class ModelManagerSerializer(DeconstructableSerializer):
    """Serialize a model manager.

    Managers built with ``QuerySet.as_manager()`` are written as
    ``<queryset path>.as_manager()``. All others are written as a constructor
    call on the manager class.
    """

    def serialize(self):
        as_manager, manager_path, qs_path, args, kwargs = self.value.deconstruct()
        if not as_manager:
            return self.serialize_deconstructed(manager_path, args, kwargs)
        qs_name, imports = self._serialize_path(qs_path)
        return f"{qs_name}.as_manager()", imports
231
232
class OperationSerializer(BaseSerializer):
    """Serialize a nested migration Operation using OperationWriter."""

    def serialize(self):
        # Imported at call time (presumably to avoid an import cycle with the
        # writer module; confirm before hoisting).
        from plain.models.migrations.writer import OperationWriter

        writer = OperationWriter(self.value, indentation=0)
        text, imports = writer.serialize()
        # The enclosing OperationWriter._write() adds the trailing comma for
        # nested operations, so drop the one produced here.
        return text.rstrip(","), imports
240
241
class PathLikeSerializer(BaseSerializer):
    """Serialize an ``os.PathLike`` as the repr of its ``os.fspath()`` value.

    Returns:
        ``(repr_of_path, imports)``. The result is a plain str/bytes literal
        that needs no import, so ``imports`` is an empty *set*, the same type
        every other serializer returns.
    """

    def serialize(self):
        return repr(os.fspath(self.value)), set()
245
246
class PathSerializer(BaseSerializer):
    """Serialize ``pathlib`` paths.

    Concrete ``pathlib.Path`` objects are written as their ``Pure*``
    counterpart, so a migration generated on one platform can be loaded on
    another.
    """

    def serialize(self):
        if isinstance(self.value, pathlib.Path):
            return f"pathlib.Pure{self.value!r}", {"import pathlib"}
        return f"pathlib.{self.value!r}", {"import pathlib"}
253
254
class RegexSerializer(BaseSerializer):
    """Serialize a compiled regex as ``re.compile(pattern[, flags])``.

    Flags that the ``re`` module adds implicitly are removed before
    rendering, because regexes with the same implicit and explicit flags
    don't compare equal. The flags argument is omitted when nothing remains.
    """

    def serialize(self):
        pattern = self.value.pattern
        regex_pattern, pattern_imports = serializer_factory(pattern).serialize()
        # Learn the implicit flags from an empty pattern of the *same* type:
        # re.UNICODE for str, nothing for bytes. Clear them with AND-NOT
        # rather than XOR. XOR would *add* re.UNICODE to a bytes pattern or
        # to a str pattern compiled with re.ASCII, and re.compile() rejects
        # both combinations.
        implicit = re.compile(pattern[:0]).flags
        flags = self.value.flags & ~implicit
        regex_flags, flag_imports = serializer_factory(flags).serialize()
        imports = {"import re", *pattern_imports, *flag_imports}
        args = [regex_pattern]
        if flags:
            args.append(regex_flags)
        return "re.compile(%s)" % ", ".join(args), imports
269
270
class SequenceSerializer(BaseSequenceSerializer):
    """Serialize a list as a ``[...]`` literal."""

    def _format(self):
        template = "[%s]"
        return template
274
275
class SetSerializer(BaseSequenceSerializer):
    """Serialize a set as a ``{...}`` literal, or ``set()`` when empty.

    Set iteration order depends on element hashes, and for ``str`` elements
    those differ between interpreter runs (hash randomization). Elements are
    therefore sorted by ``repr`` so that the same set always yields the same
    migration text.
    """

    def __init__(self, value):
        super().__init__(sorted(value, key=repr))

    def _format(self):
        # Serialize as a set literal except when value is empty because {}
        # is an empty dict.
        return "{%s}" if self.value else "set(%s)"
281
282
class SettingsReferenceSerializer(BaseSerializer):
    """Serialize a SettingsReference as ``settings.<NAME>``.

    The migration then imports ``settings`` from ``plain.runtime``.
    """

    def serialize(self):
        needed = {"from plain.runtime import settings"}
        return f"settings.{self.value.setting_name!s}", needed
288
289
class TupleSerializer(BaseSequenceSerializer):
    """Serialize a tuple literal.

    One-element tuples need a trailing comma. Empty tuples are ``()``,
    because ``(,)`` is a syntax error.
    """

    def _format(self):
        return "(%s,)" if len(self.value) == 1 else "(%s)"
295
296
class TypeSerializer(BaseSerializer):
    """Serialize a class object by reference.

    ``models.Model`` and ``NoneType`` get fixed spellings. Builtins are
    written by bare name. Any other class becomes ``module.QualName`` with an
    import of its module.
    """

    def serialize(self):
        if self.value is models.Model:
            return "models.Model", {"from plain import models"}
        if self.value is types.NoneType:
            return "types.NoneType", {"import types"}
        if not hasattr(self.value, "__module__"):
            return None
        module = self.value.__module__
        if module == builtins.__name__:
            return self.value.__name__, set()
        return f"{module}.{self.value.__qualname__}", {"import %s" % module}
312
313
class UUIDSerializer(BaseSerializer):
    """Serialize a ``uuid.UUID`` as ``uuid.UUID('...')``."""

    def serialize(self):
        needed = {"import uuid"}
        return f"uuid.{self.value!r}", needed
317
318
class Serializer:
    """Type-to-serializer registry consulted by ``serializer_factory``.

    ``serializer_factory`` walks ``_registry`` in insertion order and uses
    the first entry whose key matches via ``isinstance``. More specific types
    must therefore precede broader ones. For example, ``datetime.datetime``
    comes before ``datetime.date``, and ``collections.abc.Iterable`` comes
    after the concrete container types.
    """

    _registry = {
        # Some of these are order-dependent.
        frozenset: FrozensetSerializer,
        list: SequenceSerializer,
        set: SetSerializer,
        tuple: TupleSerializer,
        dict: DictionarySerializer,
        models.Choices: ChoicesSerializer,
        enum.Enum: EnumSerializer,
        datetime.datetime: DatetimeDatetimeSerializer,
        (datetime.date, datetime.timedelta, datetime.time): DateTimeSerializer,
        SettingsReference: SettingsReferenceSerializer,
        float: FloatSerializer,
        (bool, int, types.NoneType, bytes, str, range): BaseSimpleSerializer,
        decimal.Decimal: DecimalSerializer,
        (functools.partial, functools.partialmethod): FunctoolsPartialSerializer,
        (
            types.FunctionType,
            types.BuiltinFunctionType,
            types.MethodType,
        ): FunctionTypeSerializer,
        collections.abc.Iterable: IterableSerializer,
        (COMPILED_REGEX_TYPE, RegexObject): RegexSerializer,
        uuid.UUID: UUIDSerializer,
        pathlib.PurePath: PathSerializer,
        os.PathLike: PathLikeSerializer,
    }

    @classmethod
    def register(cls, type_, serializer):
        """Map ``type_`` to ``serializer`` (checked after existing entries).

        Raises:
            ValueError: if ``serializer`` is not a BaseSerializer subclass.
        """
        if issubclass(serializer, BaseSerializer):
            cls._registry[type_] = serializer
            return
        raise ValueError(
            "'%s' must inherit from 'BaseSerializer'." % serializer.__name__
        )

    @classmethod
    def unregister(cls, type_):
        """Remove ``type_`` from the registry; raises KeyError if absent."""
        del cls._registry[type_]
359
360
def serializer_factory(value):
    """Return the serializer instance appropriate for ``value``.

    Lazy values are resolved first: a Promise is converted with ``str()``,
    and a LazyObject is unwrapped via its ``__reduce__()`` arguments. The
    lookup order is then:
      1. model fields, managers, operations and classes;
      2. anything that has a ``deconstruct`` attribute;
      3. the first ``isinstance`` match in ``Serializer._registry``.

    Raises:
        ValueError: if no serializer applies to ``value``.
    """
    if isinstance(value, Promise):
        value = str(value)
    elif isinstance(value, LazyObject):
        # __reduce__() carries the wrapped object as the first item of its
        # argument tuple.
        value = value.__reduce__()[1][0]

    for base, serializer_cls in (
        (models.Field, ModelFieldSerializer),
        (models.manager.BaseManager, ModelManagerSerializer),
        (Operation, OperationSerializer),
        (type, TypeSerializer),
    ):
        if isinstance(value, base):
            return serializer_cls(value)

    # Anything that knows how to deconstruct itself.
    if hasattr(value, "deconstruct"):
        return DeconstructableSerializer(value)

    matched = next(
        (
            candidate
            for type_, candidate in Serializer._registry.items()
            if isinstance(value, type_)
        ),
        None,
    )
    if matched is None:
        raise ValueError(
            "Cannot serialize: %r\nThere are some values Plain cannot serialize into "
            "migration files." % value
        )
    return matched(value)