Plain is headed towards 1.0! Subscribe for development updates →

  1from __future__ import annotations
  2
  3import builtins
  4import collections.abc
  5import datetime
  6import decimal
  7import enum
  8import functools
  9import math
 10import os
 11import pathlib
 12import re
 13import types
 14import uuid
 15from typing import Any
 16
 17from plain import models
 18from plain.models.migrations.operations.base import Operation
 19from plain.models.migrations.utils import COMPILED_REGEX_TYPE, RegexObject
 20from plain.runtime import SettingsReference
 21from plain.utils.functional import LazyObject, Promise
 22
 23
 24class BaseSerializer:
 25    def __init__(self, value: Any) -> None:
 26        self.value = value
 27
 28    def serialize(self) -> tuple[str, set[str]]:
 29        raise NotImplementedError(
 30            "Subclasses of BaseSerializer must implement the serialize() method."
 31        )
 32
 33
 34class BaseSequenceSerializer(BaseSerializer):
 35    def _format(self) -> str:
 36        raise NotImplementedError(
 37            "Subclasses of BaseSequenceSerializer must implement the _format() method."
 38        )
 39
 40    def serialize(self) -> tuple[str, set[str]]:
 41        imports: set[str] = set()
 42        strings = []
 43        for item in self.value:
 44            item_string, item_imports = serializer_factory(item).serialize()
 45            imports.update(item_imports)
 46            strings.append(item_string)
 47        value = self._format()
 48        return value % (", ".join(strings)), imports
 49
 50
 51class BaseSimpleSerializer(BaseSerializer):
 52    def serialize(self) -> tuple[str, set[str]]:
 53        return repr(self.value), set()
 54
 55
 56class ChoicesSerializer(BaseSerializer):
 57    def serialize(self) -> tuple[str, set[str]]:
 58        return serializer_factory(self.value.value).serialize()
 59
 60
 61class DateTimeSerializer(BaseSerializer):
 62    """For datetime.*, except datetime.datetime."""
 63
 64    def serialize(self) -> tuple[str, set[str]]:
 65        return repr(self.value), {"import datetime"}
 66
 67
 68class DatetimeDatetimeSerializer(BaseSerializer):
 69    """For datetime.datetime."""
 70
 71    def serialize(self) -> tuple[str, set[str]]:
 72        if self.value.tzinfo is not None and self.value.tzinfo != datetime.UTC:
 73            self.value = self.value.astimezone(datetime.UTC)
 74        imports = ["import datetime"]
 75        return repr(self.value), set(imports)
 76
 77
 78class DecimalSerializer(BaseSerializer):
 79    def serialize(self) -> tuple[str, set[str]]:
 80        return repr(self.value), {"from decimal import Decimal"}
 81
 82
 83class DeconstructableSerializer(BaseSerializer):
 84    @staticmethod
 85    def serialize_deconstructed(
 86        path: str, args: tuple[Any, ...], kwargs: dict[str, Any]
 87    ) -> tuple[str, set[str]]:
 88        name, imports = DeconstructableSerializer._serialize_path(path)
 89        strings = []
 90        for arg in args:
 91            arg_string, arg_imports = serializer_factory(arg).serialize()
 92            strings.append(arg_string)
 93            imports.update(arg_imports)
 94        for kw, arg in sorted(kwargs.items()):
 95            arg_string, arg_imports = serializer_factory(arg).serialize()
 96            imports.update(arg_imports)
 97            strings.append(f"{kw}={arg_string}")
 98        return "{}({})".format(name, ", ".join(strings)), imports
 99
100    @staticmethod
101    def _serialize_path(path: str) -> tuple[str, set[str]]:
102        module, name = path.rsplit(".", 1)
103        if module == "plain.models":
104            imports: set[str] = {"from plain import models"}
105            name = f"models.{name}"
106        else:
107            imports = {f"import {module}"}
108            name = path
109        return name, imports
110
111    def serialize(self) -> tuple[str, set[str]]:
112        return self.serialize_deconstructed(*self.value.deconstruct())
113
114
class DictionarySerializer(BaseSerializer):
    """Serialize a mapping as a dict literal with its items in sorted order."""

    def serialize(self) -> tuple[str, set[str]]:
        """Return the dict literal and the imports needed by all keys and values."""
        imports: set[str] = set()
        entries = []
        for key, item in sorted(self.value.items()):
            key_source, key_imports = serializer_factory(key).serialize()
            item_source, item_imports = serializer_factory(item).serialize()
            imports.update(key_imports)
            imports.update(item_imports)
            entries.append(f"{key_source}: {item_source}")
        body = ", ".join(entries)
        return "{" + body + "}", imports
126
127
class EnumSerializer(BaseSerializer):
    """Serialize enum members as name lookups on their class.

    Plain members become ``module.Class['NAME']``. ``enum.Flag`` values are
    split into their individual members joined with ``|``.
    """

    def serialize(self) -> tuple[str, set[str]]:
        """Return the source expression for the member and its module import."""
        enum_class = self.value.__class__
        module = enum_class.__module__
        class_path = f"{module}.{enum_class.__qualname__}"
        imports = {f"import {module}"}
        if issubclass(enum_class, enum.Flag):
            members = list(self.value)
            if not members:
                # A flag with no bits set iterates as empty, which would make
                # the join below produce "" (invalid migration source). Fall
                # back to a name lookup if the zero value is a named member,
                # otherwise to calling the class with the raw value.
                if self.value.name is not None:
                    return f"{class_path}[{self.value.name!r}]", imports
                return f"{class_path}({self.value.value!r})", imports
        else:
            members = (self.value,)
        return (
            " | ".join([f"{class_path}[{item.name!r}]" for item in members]),
            imports,
        )
145
146
class FloatSerializer(BaseSimpleSerializer):
    """Serialize floats, spelling NaN and infinities as ``float("...")`` calls."""

    def serialize(self) -> tuple[str, set[str]]:
        """Return ``repr()`` for ordinary floats, a ``float("...")`` call otherwise."""
        number = self.value
        is_ordinary = not math.isnan(number) and not math.isinf(number)
        if is_ordinary:
            return super().serialize()
        return f'float("{self.value}")', set()
152
153
class FrozensetSerializer(BaseSequenceSerializer):
    """Serialize frozensets as ``frozenset([...])`` with a stable element order."""

    def __init__(self, value: Any) -> None:
        # Set iteration order is not guaranteed to be the same between
        # interpreter runs, which would make the generated source differ for
        # an unchanged value. Sorting by repr() gives a deterministic order
        # even for elements that cannot be compared with each other.
        super().__init__(sorted(value, key=repr))

    def _format(self) -> str:
        """Return the template that wraps the joined items."""
        return "frozenset([%s])"
157
158
class FunctionTypeSerializer(BaseSerializer):
    """Serialize functions and class-bound methods as dotted import paths."""

    def serialize(self) -> tuple[str, set[str]]:
        """Return the dotted path of the callable and the import of its module.

        Raises:
            ValueError: for lambdas, for functions whose ``__module__`` is
                ``None``, and for functions whose qualified name contains
                ``<`` (e.g. ``<locals>``), since those cannot be referenced
                by a dotted path.
        """
        if getattr(self.value, "__self__", None) and isinstance(
            self.value.__self__, type
        ):
            klass = self.value.__self__
            module = klass.__module__
            # Use __qualname__ rather than __name__ so that methods bound to
            # a nested class (Outer.Inner.method) get a resolvable path.
            return f"{module}.{klass.__qualname__}.{self.value.__name__}", {
                f"import {module}"
            }
        # Further error checking
        if self.value.__name__ == "<lambda>":
            raise ValueError("Cannot serialize function: lambda")
        if self.value.__module__ is None:
            raise ValueError(f"Cannot serialize function {self.value!r}: No module")

        module_name = self.value.__module__

        if "<" not in self.value.__qualname__:  # Qualname can include <locals>
            return f"{module_name}.{self.value.__qualname__}", {
                f"import {self.value.__module__}"
            }

        raise ValueError(
            f"Could not find function {self.value.__name__} in {module_name}.\n"
        )
185
186
class FunctoolsPartialSerializer(BaseSerializer):
    """Serialize partial objects as a ``functools.<ClassName>(...)`` call."""

    def serialize(self) -> tuple[str, set[str]]:
        """Return the reconstructed call and the imports of all its parts."""
        partial_obj = self.value
        # Serialize the wrapped callable, then positional and keyword arguments.
        func_source, func_imports = serializer_factory(partial_obj.func).serialize()
        args_source, args_imports = serializer_factory(partial_obj.args).serialize()
        kwargs_source, kwargs_imports = serializer_factory(
            partial_obj.keywords
        ).serialize()

        # functools itself plus whatever the individual pieces need.
        imports: set[str] = {"import functools"}
        imports.update(func_imports, args_imports, kwargs_imports)

        class_name = self.value.__class__.__name__
        call = f"functools.{class_name}({func_source}, *{args_source}, **{kwargs_source})"
        return call, imports
206
207
class IterableSerializer(BaseSerializer):
    """Serialize an arbitrary iterable as a tuple literal."""

    def serialize(self) -> tuple[str, set[str]]:
        """Return the tuple literal and the imports needed by its items."""
        imports: set[str] = set()
        rendered_items = []
        for element in self.value:
            rendered, needed = serializer_factory(element).serialize()
            imports.update(needed)
            rendered_items.append(rendered)
        # Only a one-item tuple needs the trailing comma; an empty iterable
        # must be "()" because "(,)" is invalid Python syntax.
        if len(rendered_items) == 1:
            template = "(%s,)"
        else:
            template = "(%s)"
        joined = ", ".join(rendered_items)
        return template % (joined), imports
220
221
class ModelFieldSerializer(DeconstructableSerializer):
    """Serialize objects whose ``deconstruct()`` returns a 4-tuple."""

    def serialize(self) -> tuple[str, set[str]]:
        """Serialize path, args and kwargs; the first deconstructed item is unused."""
        deconstructed = self.value.deconstruct()
        _unused_name, path, args, kwargs = deconstructed
        return self.serialize_deconstructed(path, args, kwargs)
226
227
class OperationSerializer(BaseSerializer):
    """Serialize a nested operation through ``OperationWriter``."""

    def serialize(self) -> tuple[str, set[str]]:
        """Return the operation's source without trailing commas, plus imports."""
        from plain.models.migrations.writer import OperationWriter

        writer = OperationWriter(self.value, indentation=0)
        source, imports = writer.serialize()
        # Nested operation, trailing comma is handled in upper OperationWriter._write()
        trimmed = source.rstrip(",")
        return trimmed, imports
235
236
class PathLikeSerializer(BaseSerializer):
    """Serialize a path-like object as the ``repr()`` of ``os.fspath()`` of it."""

    def serialize(self) -> tuple[str, set[str]]:
        """Return the repr of the file-system path; no imports are needed."""
        fs_path = os.fspath(self.value)
        return repr(fs_path), set()
240
241
class PathSerializer(BaseSerializer):
    """Serialize pathlib paths as pure-path constructor calls."""

    def serialize(self) -> tuple[str, set[str]]:
        """Return ``pathlib.<repr>``, prefixed with ``Pure`` for concrete paths."""
        # Convert concrete paths to pure paths to avoid issues with migrations
        # generated on one platform being used on a different platform.
        if isinstance(self.value, pathlib.Path):
            prefix = "Pure"
        else:
            prefix = ""
        source = f"pathlib.{prefix}{self.value!r}"
        return source, {"import pathlib"}
248
249
class RegexSerializer(BaseSerializer):
    """Serialize compiled regular expressions as ``re.compile(...)`` calls."""

    def serialize(self) -> tuple[str, set[str]]:
        """Return the ``re.compile`` call; flags are included only when non-zero."""
        pattern_source, pattern_imports = serializer_factory(
            self.value.pattern
        ).serialize()
        # Turn off default implicit flags (e.g. re.U) because regexes with the
        # same implicit and explicit flags aren't equal.
        own_flags = self.value.flags
        explicit_flags = own_flags ^ re.compile("").flags
        flags_source, flag_imports = serializer_factory(explicit_flags).serialize()

        imports: set[str] = {"import re"}
        imports.update(pattern_imports)
        imports.update(flag_imports)

        call_args = [pattern_source]
        if explicit_flags:
            call_args.append(flags_source)
        joined = ", ".join(call_args)
        return f"re.compile({joined})", imports
264
265
class SequenceSerializer(BaseSequenceSerializer):
    """Serialize a sequence as a list literal."""

    def _format(self) -> str:
        """Return the template that wraps the joined items in square brackets."""
        list_template = "[%s]"
        return list_template
269
270
class SetSerializer(BaseSequenceSerializer):
    """Serialize sets as set literals with a stable element order."""

    def __init__(self, value: Any) -> None:
        # Set iteration order is not guaranteed to be the same between
        # interpreter runs, which would make the generated source differ for
        # an unchanged value. Sorting by repr() gives a deterministic order
        # even for elements that cannot be compared with each other.
        super().__init__(sorted(value, key=repr))

    def _format(self) -> str:
        """Return the template that wraps the joined items."""
        # Serialize as a set literal except when value is empty because {}
        # is an empty dict.
        return "{%s}" if self.value else "set(%s)"
276
277
class SettingsReferenceSerializer(BaseSerializer):
    """Serialize a settings reference as an attribute lookup on ``settings``."""

    def serialize(self) -> tuple[str, set[str]]:
        """Return ``settings.<setting_name>`` and the import of ``settings``."""
        setting_name = self.value.setting_name
        required = {"from plain.runtime import settings"}
        return f"settings.{setting_name}", required
283
284
class TupleSerializer(BaseSequenceSerializer):
    """Serialize tuples as tuple literals."""

    def _format(self) -> str:
        """Return the tuple template, with a trailing comma only for one item."""
        # A one-item tuple needs the trailing comma; an empty tuple must be
        # "()" because "(,)" is invalid Python syntax.
        if len(self.value) == 1:
            return "(%s,)"
        return "(%s)"
290
291
class TypeSerializer(BaseSerializer):
    """Serialize classes as a reference to the class itself."""

    def serialize(self) -> tuple[str, set[str]]:
        """Return the name or dotted path of the class and any import it needs.

        ``models.Model`` and ``types.NoneType`` have fixed spellings, builtins
        are emitted by bare name, and everything else as
        ``<module>.<qualname>``.
        """
        klass = self.value
        if klass is models.Model:
            return "models.Model", {"from plain import models"}
        if klass is types.NoneType:
            return "types.NoneType", {"import types"}
        if not hasattr(klass, "__module__"):
            return "", set()
        module = klass.__module__
        if module == builtins.__name__:
            return klass.__name__, set()
        return f"{module}.{klass.__qualname__}", {f"import {module}"}
308
309
class UUIDSerializer(BaseSerializer):
    """Serialize a value as ``uuid.`` followed by its ``repr()``."""

    def serialize(self) -> tuple[str, set[str]]:
        """Return the ``uuid.``-prefixed repr and the uuid import."""
        source = "uuid." + repr(self.value)
        return source, {"import uuid"}
313
314
class Serializer:
    """Registry mapping Python types to the serializer class that handles them.

    ``serializer_factory`` iterates ``_registry`` in insertion order and uses
    the first entry whose key passes ``isinstance(value, key)``. Keys are
    either a single type or a tuple of types.
    """

    # Shared, mutable class-level mapping; register() modifies it in place.
    _registry = {
        # Some of these are order-dependent.
        frozenset: FrozensetSerializer,
        list: SequenceSerializer,
        set: SetSerializer,
        tuple: TupleSerializer,
        dict: DictionarySerializer,
        models.Choices: ChoicesSerializer,
        enum.Enum: EnumSerializer,
        # datetime.datetime is a subclass of datetime.date, so it must be
        # matched before the generic datetime entry below.
        datetime.datetime: DatetimeDatetimeSerializer,
        (datetime.date, datetime.timedelta, datetime.time): DateTimeSerializer,
        SettingsReference: SettingsReferenceSerializer,
        float: FloatSerializer,
        (bool, int, types.NoneType, bytes, str, range): BaseSimpleSerializer,
        decimal.Decimal: DecimalSerializer,
        (functools.partial, functools.partialmethod): FunctoolsPartialSerializer,
        (
            types.FunctionType,
            types.BuiltinFunctionType,
            types.MethodType,
        ): FunctionTypeSerializer,
        # Generic fallback for iterables; the concrete containers and the
        # str/bytes/range entry above are iterable too and are matched first.
        collections.abc.Iterable: IterableSerializer,
        (COMPILED_REGEX_TYPE, RegexObject): RegexSerializer,
        uuid.UUID: UUIDSerializer,
        # pathlib paths also satisfy os.PathLike, so they are listed first.
        pathlib.PurePath: PathSerializer,
        os.PathLike: PathLikeSerializer,
    }

    @classmethod
    def register(cls, type_: type[Any], serializer: type[BaseSerializer]) -> None:
        """Associate ``type_`` with ``serializer`` in the registry.

        A new key is appended after all existing entries, so it is consulted
        last; an existing key keeps its position and only its serializer is
        replaced.

        Raises:
            ValueError: if ``serializer`` is not a subclass of BaseSerializer.
        """
        if not issubclass(serializer, BaseSerializer):
            raise ValueError(
                f"'{serializer.__name__}' must inherit from 'BaseSerializer'."
            )
        cls._registry[type_] = serializer
351
352
def serializer_factory(value: Any) -> BaseSerializer:
    """Pick and instantiate the serializer that handles ``value``.

    Promise values are converted with ``str()`` and LazyObject values are
    unwrapped first. Fields, operations and classes are checked next, then
    anything with a ``deconstruct`` attribute, and finally the entries of
    ``Serializer._registry`` in order.

    Raises:
        ValueError: if no serializer matches.
    """
    if isinstance(value, Promise):
        value = str(value)
    elif isinstance(value, LazyObject):
        # The unwrapped value is returned as the first item of the arguments
        # tuple.
        reduce_args = value.__reduce__()[1]
        value = reduce_args[0]

    priority_dispatch = (
        (models.Field, ModelFieldSerializer),
        (Operation, OperationSerializer),
        (type, TypeSerializer),
    )
    for expected_type, dedicated_cls in priority_dispatch:
        if isinstance(value, expected_type):
            return dedicated_cls(value)

    # Anything that knows how to deconstruct itself.
    if hasattr(value, "deconstruct"):
        return DeconstructableSerializer(value)

    for registered_type in Serializer._registry:
        if isinstance(value, registered_type):
            return Serializer._registry[registered_type](value)

    raise ValueError(
        f"Cannot serialize: {value!r}\nThere are some values Plain cannot serialize into "
        "migration files."
    )