Plain is headed towards 1.0! Subscribe for development updates →

  1import builtins
  2import collections.abc
  3import datetime
  4import decimal
  5import enum
  6import functools
  7import math
  8import os
  9import pathlib
 10import re
 11import types
 12import uuid
 13
 14from plain import models
 15from plain.models.migrations.operations.base import Operation
 16from plain.models.migrations.utils import COMPILED_REGEX_TYPE, RegexObject
 17from plain.runtime import SettingsReference
 18from plain.utils.functional import LazyObject, Promise
 19
 20
 21class BaseSerializer:
 22    def __init__(self, value):
 23        self.value = value
 24
 25    def serialize(self):
 26        raise NotImplementedError(
 27            "Subclasses of BaseSerializer must implement the serialize() method."
 28        )
 29
 30
 31class BaseSequenceSerializer(BaseSerializer):
 32    def _format(self):
 33        raise NotImplementedError(
 34            "Subclasses of BaseSequenceSerializer must implement the _format() method."
 35        )
 36
 37    def serialize(self):
 38        imports = set()
 39        strings = []
 40        for item in self.value:
 41            item_string, item_imports = serializer_factory(item).serialize()
 42            imports.update(item_imports)
 43            strings.append(item_string)
 44        value = self._format()
 45        return value % (", ".join(strings)), imports
 46
 47
 48class BaseSimpleSerializer(BaseSerializer):
 49    def serialize(self):
 50        return repr(self.value), set()
 51
 52
 53class ChoicesSerializer(BaseSerializer):
 54    def serialize(self):
 55        return serializer_factory(self.value.value).serialize()
 56
 57
 58class DateTimeSerializer(BaseSerializer):
 59    """For datetime.*, except datetime.datetime."""
 60
 61    def serialize(self):
 62        return repr(self.value), {"import datetime"}
 63
 64
 65class DatetimeDatetimeSerializer(BaseSerializer):
 66    """For datetime.datetime."""
 67
 68    def serialize(self):
 69        if self.value.tzinfo is not None and self.value.tzinfo != datetime.UTC:
 70            self.value = self.value.astimezone(datetime.UTC)
 71        imports = ["import datetime"]
 72        return repr(self.value), set(imports)
 73
 74
 75class DecimalSerializer(BaseSerializer):
 76    def serialize(self):
 77        return repr(self.value), {"from decimal import Decimal"}
 78
 79
 80class DeconstructableSerializer(BaseSerializer):
 81    @staticmethod
 82    def serialize_deconstructed(path, args, kwargs):
 83        name, imports = DeconstructableSerializer._serialize_path(path)
 84        strings = []
 85        for arg in args:
 86            arg_string, arg_imports = serializer_factory(arg).serialize()
 87            strings.append(arg_string)
 88            imports.update(arg_imports)
 89        for kw, arg in sorted(kwargs.items()):
 90            arg_string, arg_imports = serializer_factory(arg).serialize()
 91            imports.update(arg_imports)
 92            strings.append(f"{kw}={arg_string}")
 93        return "{}({})".format(name, ", ".join(strings)), imports
 94
 95    @staticmethod
 96    def _serialize_path(path):
 97        module, name = path.rsplit(".", 1)
 98        if module == "plain.models":
 99            imports = {"from plain import models"}
100            name = f"models.{name}"
101        else:
102            imports = {f"import {module}"}
103            name = path
104        return name, imports
105
106    def serialize(self):
107        return self.serialize_deconstructed(*self.value.deconstruct())
108
109
class DictionarySerializer(BaseSerializer):
    """Serialize a ``dict`` as a dict literal, serializing keys and values recursively.

    Entries are emitted sorted by key so the generated text is stable. When the
    keys cannot be ordered against one another (e.g. a mix of ``int`` and
    ``str`` keys, where ``<`` raises ``TypeError``), the dict's insertion order
    is used instead of failing.
    """

    def serialize(self):
        imports = set()
        strings = []
        for k, v in self._ordered_items():
            k_string, k_imports = serializer_factory(k).serialize()
            v_string, v_imports = serializer_factory(v).serialize()
            imports.update(k_imports)
            imports.update(v_imports)
            strings.append((k_string, v_string))
        return "{{{}}}".format(", ".join(f"{k}: {v}" for k, v in strings)), imports

    def _ordered_items(self):
        """Return the items of ``self.value`` sorted by key, or in insertion order if unorderable."""
        items = list(self.value.items())
        try:
            return sorted(items)
        except TypeError:
            # Keys are not mutually comparable; insertion order is still deterministic.
            return items
121
122
class EnumSerializer(BaseSerializer):
    """Serialize an enum member as a by-name lookup on its class.

    Plain enums become ``module.Cls['NAME']``. Flag values become the members
    they contain OR-ed together (``module.Cls['A'] | module.Cls['B']``). A flag
    value that cannot be expressed that way is rendered as the call
    ``module.Cls(<int>)`` instead. This covers zero (which contains no
    members, so joining would produce an empty, invalid expression) and values
    carrying bits that belong to no member.
    """

    def serialize(self):
        enum_class = self.value.__class__
        module = enum_class.__module__
        class_path = f"{module}.{enum_class.__qualname__}"
        imports = {f"import {module}"}
        if issubclass(enum_class, enum.Flag):
            # Iterating a Flag value (Python 3.11+) yields the members it contains.
            members = list(self.value)
            combined = 0
            for member in members:
                combined |= member.value
            if not members or combined != self.value.value:
                return f"{class_path}({self.value.value!r})", imports
        else:
            members = (self.value,)
        return (
            " | ".join(f"{class_path}[{item.name!r}]" for item in members),
            imports,
        )
140
141
class FloatSerializer(BaseSimpleSerializer):
    """Serialize floats, spelling NaN and infinities as ``float("...")`` calls.

    ``repr`` of those values (``nan``, ``inf``, ``-inf``) would not evaluate
    back to a float, so they are wrapped in ``float()``; finite values use the
    plain ``repr`` from ``BaseSimpleSerializer``.
    """

    def serialize(self):
        if not math.isfinite(self.value):
            return f'float("{self.value}")', set()
        return super().serialize()
147
148
class FrozensetSerializer(BaseSequenceSerializer):
    """Serialize a ``frozenset`` as ``frozenset([...])``.

    Set iteration order depends on element hashes (randomized per process for
    ``str``), so items are sorted by ``repr`` up front. This makes the
    generated migration text identical across runs.
    """

    def __init__(self, value):
        super().__init__(sorted(value, key=repr))

    def _format(self):
        return "frozenset([%s])"
152
153
class FunctionTypeSerializer(BaseSerializer):
    """Serialize a function or method as a dotted reference to it.

    Methods whose ``__self__`` is a class (e.g. classmethods accessed on the
    class) render as ``module.Class.method``, using the class ``__qualname__``
    so methods of nested classes resolve (``Outer.Inner.method``). Other
    callables render as ``module.<qualname>``.

    Raises:
        ValueError: for lambdas, callables without a module, and callables or
            classes defined inside a function (``<locals>`` in the qualname),
            none of which can be referenced by import path.
    """

    def serialize(self):
        if getattr(self.value, "__self__", None) and isinstance(
            self.value.__self__, type
        ):
            klass = self.value.__self__
            module = klass.__module__
            if "<" in klass.__qualname__:  # Class defined in a function body.
                raise ValueError(
                    f"Could not find function {self.value.__name__} in {module}.\n"
                )
            return f"{module}.{klass.__qualname__}.{self.value.__name__}", {
                f"import {module}"
            }
        # Further error checking
        if self.value.__name__ == "<lambda>":
            raise ValueError("Cannot serialize function: lambda")
        if self.value.__module__ is None:
            raise ValueError(f"Cannot serialize function {self.value!r}: No module")

        module_name = self.value.__module__

        if "<" not in self.value.__qualname__:  # Qualname can include <locals>
            return f"{module_name}.{self.value.__qualname__}", {
                f"import {self.value.__module__}"
            }

        raise ValueError(
            f"Could not find function {self.value.__name__} in {module_name}.\n"
        )
180
181
class FunctoolsPartialSerializer(BaseSerializer):
    """Serialize ``functools.partial`` / ``partialmethod`` objects.

    Rendered as ``functools.<Class>(func, *args, **keywords)``, where the
    wrapped callable, positional args and keywords are each serialized
    recursively and their imports merged with ``import functools``.
    """

    def serialize(self):
        partial_obj = self.value
        sources = []
        imports = {"import functools"}
        for component in (partial_obj.func, partial_obj.args, partial_obj.keywords):
            text, needed = serializer_factory(component).serialize()
            sources.append(text)
            imports.update(needed)
        func_src, args_src, keywords_src = sources
        class_name = partial_obj.__class__.__name__
        return (
            f"functools.{class_name}({func_src}, *{args_src}, **{keywords_src})",
            imports,
        )
196
197
class IterableSerializer(BaseSerializer):
    """Serialize an arbitrary iterable as a tuple literal of its items."""

    def serialize(self):
        imports = set()
        parts = []
        for element in self.value:
            text, needed = serializer_factory(element).serialize()
            imports.update(needed)
            parts.append(text)
        joined = ", ".join(parts)
        # Exactly one item needs a trailing comma to stay a tuple; the empty
        # case must not get one because "(,)" is invalid Python syntax.
        if len(parts) == 1:
            return f"({joined},)", imports
        return f"({joined})", imports
210
211
class ModelFieldSerializer(DeconstructableSerializer):
    """Serialize a model field from its ``deconstruct()`` output.

    Field ``deconstruct()`` returns a 4-tuple whose first element (the
    attribute name) is not part of the constructor call and is discarded.
    """

    def serialize(self):
        _attr_name, path, args, kwargs = self.value.deconstruct()
        return self.serialize_deconstructed(path, args, kwargs)
216
217
class OperationSerializer(BaseSerializer):
    """Serialize a migration ``Operation`` that appears nested inside another value."""

    def serialize(self):
        # Imported at call time (presumably to avoid a circular import with the
        # writer module — confirm).
        from plain.models.migrations.writer import OperationWriter

        writer = OperationWriter(self.value, indentation=0)
        string, imports = writer.serialize()
        # Nested operation, trailing comma is handled in upper OperationWriter._write()
        return string.rstrip(","), imports
225
226
class PathLikeSerializer(BaseSerializer):
    """Serialize an ``os.PathLike`` object as the ``repr`` of its filesystem path.

    No import is required, so an empty *set* is returned. Every other
    serializer returns a set, and callers merge the result with set
    operations, so an empty dict would be inconsistent.
    """

    def serialize(self):
        return repr(os.fspath(self.value)), set()
230
231
class PathSerializer(BaseSerializer):
    """Serialize ``pathlib`` paths, rendering concrete paths as their pure variants.

    A concrete path's ``repr`` (e.g. ``PosixPath('...')``) is prefixed with
    ``Pure`` so the migration builds a ``PurePosixPath`` / ``PureWindowsPath``.
    """

    def serialize(self):
        source = repr(self.value)
        if isinstance(self.value, pathlib.Path):
            # Convert concrete paths to pure paths to avoid issues with migrations
            # generated on one platform being used on a different platform.
            source = "Pure" + source
        return f"pathlib.{source}", {"import pathlib"}
238
239
class RegexSerializer(BaseSerializer):
    """Serialize a compiled regex as ``re.compile(pattern[, flags])``.

    Only flags beyond those ``re.compile`` adds implicitly are emitted, so a
    default-flag pattern serializes as ``re.compile(pattern)``.
    """

    def serialize(self):
        regex_pattern, pattern_imports = serializer_factory(
            self.value.pattern
        ).serialize()
        # Turn off default implicit flags (e.g. re.U) because regexes with the
        # same implicit and explicit flags aren't equal. The implicit flags
        # depend on the pattern type. str patterns get re.UNICODE and bytes
        # patterns get none. So compare against an empty pattern of the same
        # type. Otherwise a bytes pattern would be emitted with re.UNICODE,
        # which re.compile() rejects for bytes.
        default_flags = re.compile(self.value.pattern[:0]).flags
        flags = self.value.flags ^ default_flags
        regex_flags, flag_imports = serializer_factory(flags).serialize()
        imports = {"import re", *pattern_imports, *flag_imports}
        args = [regex_pattern]
        if flags:
            args.append(regex_flags)
        return "re.compile({})".format(", ".join(args)), imports
254
255
class SequenceSerializer(BaseSequenceSerializer):
    """Serialize a ``list`` as a list literal."""

    def _format(self):
        list_template = "[%s]"
        return list_template
259
260
class SetSerializer(BaseSequenceSerializer):
    """Serialize a ``set`` as a set literal (or ``set()`` when empty).

    Set iteration order depends on element hashes (randomized per process for
    ``str``), so items are sorted by ``repr`` up front. This makes the
    generated migration text identical across runs.
    """

    def __init__(self, value):
        super().__init__(sorted(value, key=repr))

    def _format(self):
        # Serialize as a set literal except when value is empty because {}
        # is an empty dict.
        return "{%s}" if self.value else "set(%s)"
266
267
class SettingsReferenceSerializer(BaseSerializer):
    """Serialize a ``SettingsReference`` as an attribute lookup on ``settings``."""

    def serialize(self):
        setting = self.value.setting_name
        needed = {"from plain.runtime import settings"}
        return f"settings.{setting}", needed
273
274
class TupleSerializer(BaseSequenceSerializer):
    """Serialize a ``tuple`` as a tuple literal."""

    def _format(self):
        # A single-element tuple needs a trailing comma; the empty tuple must
        # not have one because "(,)" is invalid Python syntax.
        if len(self.value) == 1:
            return "(%s,)"
        return "(%s)"
280
281
class TypeSerializer(BaseSerializer):
    """Serialize a class object as a reference to it.

    ``models.Model`` and ``NoneType`` have fixed spellings. Builtins are
    referenced by bare name, and other classes as ``module.<qualname>``. If
    the value has no ``__module__``, this falls through and returns ``None``
    (not expected for ordinary classes).
    """

    def serialize(self):
        if self.value is models.Model:
            return "models.Model", {"from plain import models"}
        if self.value is types.NoneType:
            return "types.NoneType", {"import types"}
        if not hasattr(self.value, "__module__"):
            return None
        module = self.value.__module__
        if module == builtins.__name__:
            return self.value.__name__, set()
        return f"{module}.{self.value.__qualname__}", {f"import {module}"}
297
298
class UUIDSerializer(BaseSerializer):
    """Serialize a ``uuid.UUID`` as ``uuid.UUID('...')``."""

    def serialize(self):
        source = "uuid." + repr(self.value)
        return source, {"import uuid"}
302
303
class Serializer:
    """Registry mapping value types to the serializer class that renders them.

    ``serializer_factory`` walks ``_registry`` in insertion order and returns
    the serializer for the first key the value is an ``isinstance`` of. Keys
    may be a single type or a tuple of types. The dict lives on the class and
    is shared; ``register`` mutates it in place.
    """

    _registry = {
        # Some of these are order-dependent.
        # Lookup is first-match-wins, so more specific types must come before
        # types they are also instances of. Examples:
        # - datetime.datetime is a subclass of datetime.date.
        # - IntEnum members are also ints, so enum.Enum precedes int.
        # - presumably models.Choices subclasses enum.Enum (hence listed
        #   first) — confirm.
        # - str, bytes, range and the containers above are all
        #   collections.abc.Iterable, which serves as a late catch-all.
        frozenset: FrozensetSerializer,
        list: SequenceSerializer,
        set: SetSerializer,
        tuple: TupleSerializer,
        dict: DictionarySerializer,
        models.Choices: ChoicesSerializer,
        enum.Enum: EnumSerializer,
        datetime.datetime: DatetimeDatetimeSerializer,
        (datetime.date, datetime.timedelta, datetime.time): DateTimeSerializer,
        SettingsReference: SettingsReferenceSerializer,
        float: FloatSerializer,
        (bool, int, types.NoneType, bytes, str, range): BaseSimpleSerializer,
        decimal.Decimal: DecimalSerializer,
        (functools.partial, functools.partialmethod): FunctoolsPartialSerializer,
        (
            types.FunctionType,
            types.BuiltinFunctionType,
            types.MethodType,
        ): FunctionTypeSerializer,
        collections.abc.Iterable: IterableSerializer,
        (COMPILED_REGEX_TYPE, RegexObject): RegexSerializer,
        uuid.UUID: UUIDSerializer,
        pathlib.PurePath: PathSerializer,
        os.PathLike: PathLikeSerializer,
    }

    @classmethod
    def register(cls, type_, serializer):
        """Register ``serializer`` for values that are instances of ``type_``.

        A new key is appended after all existing entries. A type already
        matched by an earlier entry (anything iterable, for instance) therefore
        keeps using that earlier serializer. Re-registering an existing key
        replaces its serializer at the same position.

        Raises:
            ValueError: if ``serializer`` is a class that does not inherit
                from ``BaseSerializer``. (A non-class argument makes
                ``issubclass`` itself raise ``TypeError``.)
        """
        if not issubclass(serializer, BaseSerializer):
            raise ValueError(
                f"'{serializer.__name__}' must inherit from 'BaseSerializer'."
            )
        cls._registry[type_] = serializer
340
341
def serializer_factory(value):
    """Return the serializer instance responsible for rendering ``value``.

    Lazy values are resolved first: ``Promise`` becomes ``str`` and
    ``LazyObject`` becomes its wrapped object. Model fields, migration
    operations, classes and anything with a ``deconstruct()`` method are
    handled before the ``Serializer._registry`` lookup, which is
    first-match-wins.

    Raises:
        ValueError: if no serializer matches ``value``.
    """
    if isinstance(value, Promise):
        value = str(value)
    elif isinstance(value, LazyObject):
        # The unwrapped value is returned as the first item of the arguments
        # tuple.
        value = value.__reduce__()[1][0]

    # Checked ahead of the generic paths, in this exact order.
    for kind, serializer_cls in (
        (models.Field, ModelFieldSerializer),
        (Operation, OperationSerializer),
        (type, TypeSerializer),
    ):
        if isinstance(value, kind):
            return serializer_cls(value)

    # Anything that knows how to deconstruct itself.
    if hasattr(value, "deconstruct"):
        return DeconstructableSerializer(value)

    matched = next(
        (
            registered_cls
            for type_, registered_cls in Serializer._registry.items()
            if isinstance(value, type_)
        ),
        None,
    )
    if matched is not None:
        return matched(value)
    raise ValueError(
        f"Cannot serialize: {value!r}\nThere are some values Plain cannot serialize into "
        "migration files."
    )