Plain is headed towards 1.0! Subscribe for development updates →

  1from __future__ import annotations
  2
  3import builtins
  4import collections.abc
  5import datetime
  6import decimal
  7import enum
  8import functools
  9import math
 10import os
 11import pathlib
 12import re
 13import types
 14import uuid
 15from abc import ABC, abstractmethod
 16from typing import Any
 17
 18from plain.models.base import Model
 19from plain.models.enums import Choices
 20from plain.models.fields import Field
 21from plain.models.migrations.operations.base import Operation
 22from plain.models.migrations.utils import COMPILED_REGEX_TYPE, RegexObject
 23from plain.runtime import SettingsReference
 24from plain.utils.functional import LazyObject, Promise
 25
 26
 27class BaseSerializer(ABC):
 28    def __init__(self, value: Any) -> None:
 29        self.value = value
 30
 31    @abstractmethod
 32    def serialize(self) -> tuple[str, set[str]]: ...
 33
 34
 35class BaseSequenceSerializer(BaseSerializer, ABC):
 36    @abstractmethod
 37    def _format(self) -> str: ...
 38
 39    def serialize(self) -> tuple[str, set[str]]:
 40        imports: set[str] = set()
 41        strings = []
 42        for item in self.value:
 43            item_string, item_imports = serializer_factory(item).serialize()
 44            imports.update(item_imports)
 45            strings.append(item_string)
 46        value = self._format()
 47        return value % (", ".join(strings)), imports
 48
 49
 50class BaseSimpleSerializer(BaseSerializer):
 51    def serialize(self) -> tuple[str, set[str]]:
 52        return repr(self.value), set()
 53
 54
 55class ChoicesSerializer(BaseSerializer):
 56    def serialize(self) -> tuple[str, set[str]]:
 57        return serializer_factory(self.value.value).serialize()
 58
 59
 60class DateTimeSerializer(BaseSerializer):
 61    """For datetime.*, except datetime.datetime."""
 62
 63    def serialize(self) -> tuple[str, set[str]]:
 64        return repr(self.value), {"import datetime"}
 65
 66
 67class DatetimeDatetimeSerializer(BaseSerializer):
 68    """For datetime.datetime."""
 69
 70    def serialize(self) -> tuple[str, set[str]]:
 71        if self.value.tzinfo is not None and self.value.tzinfo != datetime.UTC:
 72            self.value = self.value.astimezone(datetime.UTC)
 73        imports = ["import datetime"]
 74        return repr(self.value), set(imports)
 75
 76
 77class DecimalSerializer(BaseSerializer):
 78    def serialize(self) -> tuple[str, set[str]]:
 79        return repr(self.value), {"from decimal import Decimal"}
 80
 81
 82class DeconstructableSerializer(BaseSerializer):
 83    @staticmethod
 84    def serialize_deconstructed(
 85        path: str, args: tuple[Any, ...], kwargs: dict[str, Any]
 86    ) -> tuple[str, set[str]]:
 87        name, imports = DeconstructableSerializer._serialize_path(path)
 88        strings = []
 89        for arg in args:
 90            arg_string, arg_imports = serializer_factory(arg).serialize()
 91            strings.append(arg_string)
 92            imports.update(arg_imports)
 93        for kw, arg in sorted(kwargs.items()):
 94            arg_string, arg_imports = serializer_factory(arg).serialize()
 95            imports.update(arg_imports)
 96            strings.append(f"{kw}={arg_string}")
 97        return "{}({})".format(name, ", ".join(strings)), imports
 98
 99    @staticmethod
100    def _serialize_path(path: str) -> tuple[str, set[str]]:
101        module, name = path.rsplit(".", 1)
102        if module == "plain.models":
103            imports: set[str] = {"from plain import models"}
104            name = f"models.{name}"
105        else:
106            imports = {f"import {module}"}
107            name = path
108        return name, imports
109
110    def serialize(self) -> tuple[str, set[str]]:
111        return self.serialize_deconstructed(*self.value.deconstruct())
112
113
class DictionarySerializer(BaseSerializer):
    """Serialize a ``dict`` as a literal with recursively serialized entries.

    Entries are emitted in sorted key order, so the output does not depend on
    insertion order. If the keys cannot be ordered against each other (e.g. a
    mix of ``int`` and ``str``), the dict's insertion order is used instead
    of failing with an unrelated ``TypeError``.
    """

    def serialize(self) -> tuple[str, set[str]]:
        try:
            items = sorted(self.value.items())
        except TypeError:
            # Mutually unorderable key types: fall back to the dict's own
            # (deterministic) iteration order.
            items = list(self.value.items())
        imports: set[str] = set()
        entries: list[str] = []
        for key, val in items:
            key_code, key_imports = serializer_factory(key).serialize()
            val_code, val_imports = serializer_factory(val).serialize()
            imports.update(key_imports)
            imports.update(val_imports)
            entries.append(f"{key_code}: {val_code}")
        return "{" + ", ".join(entries) + "}", imports
125
126
class EnumSerializer(BaseSerializer):
    """Serialize an ``enum.Enum`` member as a by-name lookup on its class.

    Ordinary members render as ``module.Class['NAME']``. ``enum.Flag`` values
    are expanded into their member flags and joined with ``|``. A flag value
    with no members (e.g. ``Class(0)``) is rendered as
    ``module.Class(<value>)``, because an empty ``|`` join would yield an
    empty, invalid expression.
    """

    def serialize(self) -> tuple[str, set[str]]:
        enum_class = self.value.__class__
        module = enum_class.__module__
        class_path = f"{module}.{enum_class.__qualname__}"
        imports = {f"import {module}"}
        if issubclass(enum_class, enum.Flag):
            members = list(self.value)
            if not members:
                # Nothing to join: construct the empty flag from its value.
                return f"{class_path}({self.value.value!r})", imports
        else:
            members = [self.value]
        rendered = " | ".join(
            f"{class_path}[{item.name!r}]" for item in members
        )
        return rendered, imports
144
145
class FloatSerializer(BaseSimpleSerializer):
    """Serialize floats, writing non-finite values as ``float("...")``."""

    def serialize(self) -> tuple[str, set[str]]:
        if math.isfinite(self.value):
            return super().serialize()
        # repr() yields bare nan/inf/-inf, which are not Python literals.
        return f'float("{self.value}")', set()
151
152
class FrozensetSerializer(BaseSequenceSerializer):
    """Serialize a ``frozenset`` as ``frozenset([...])``.

    Elements are sorted by ``repr`` before serialization. Set iteration order
    depends on element hashes, and ``str`` hashes are randomized per process,
    so emitting elements in iteration order could produce different migration
    source for the same value on different runs.
    """

    def __init__(self, value: Any) -> None:
        super().__init__(sorted(value, key=repr))

    def _format(self) -> str:
        return "frozenset([%s])"
156
157
class FunctionTypeSerializer(BaseSerializer):
    """Serialize functions and methods as an importable dotted path.

    Callables whose ``__self__`` is a class become ``module.Class.name``.
    Other functions become ``module.qualname``. The following raise
    ``ValueError`` because they cannot be imported by name:

    - lambdas,
    - functions without a module,
    - functions whose qualname contains ``<`` (e.g. ``<locals>``).
    """

    def serialize(self) -> tuple[str, set[str]]:
        func = self.value
        owner = getattr(func, "__self__", None)
        if owner and isinstance(owner, type):
            owner_module = owner.__module__
            dotted = f"{owner_module}.{owner.__name__}.{func.__name__}"
            return dotted, {f"import {owner_module}"}

        # Further error checking
        if func.__name__ == "<lambda>":
            raise ValueError("Cannot serialize function: lambda")
        module_name = func.__module__
        if module_name is None:
            raise ValueError(f"Cannot serialize function {func!r}: No module")

        qualname = func.__qualname__
        if "<" in qualname:  # Qualname can include <locals>
            raise ValueError(
                f"Could not find function {func.__name__} in {module_name}.\n"
            )
        return f"{module_name}.{qualname}", {f"import {module_name}"}
184
185
class FunctoolsPartialSerializer(BaseSerializer):
    """Serialize ``functools.partial`` / ``functools.partialmethod`` objects.

    Produces ``functools.<ClassName>(func, *args, **keywords)``. The wrapped
    callable, positional arguments and keyword arguments are each serialized
    recursively.
    """

    def serialize(self) -> tuple[str, set[str]]:
        wrapper = self.value
        func_code, func_imports = serializer_factory(wrapper.func).serialize()
        args_code, args_imports = serializer_factory(wrapper.args).serialize()
        kwargs_code, kwargs_imports = serializer_factory(
            wrapper.keywords
        ).serialize()
        # functools itself plus whatever the arguments require.
        imports: set[str] = {"import functools"}
        imports.update(func_imports, args_imports, kwargs_imports)
        class_name = wrapper.__class__.__name__
        call = f"functools.{class_name}({func_code}, *{args_code}, **{kwargs_code})"
        return call, imports
205
206
class IterableSerializer(BaseSerializer):
    """Serialize any other iterable as a tuple literal of its elements."""

    def serialize(self) -> tuple[str, set[str]]:
        imports: set[str] = set()
        elements: list[str] = []
        for element in self.value:
            code, element_imports = serializer_factory(element).serialize()
            imports.update(element_imports)
            elements.append(code)
        body = ", ".join(elements)
        # Only a one-element tuple gets a trailing comma; "(,)" would be
        # invalid syntax for the empty case.
        if len(elements) == 1:
            return f"({body},)", imports
        return f"({body})", imports
219
220
class ModelFieldSerializer(DeconstructableSerializer):
    """Serialize a model ``Field`` from its four-item ``deconstruct()``.

    The first item (the attribute name) is not part of the constructor call
    and is discarded.
    """

    def serialize(self) -> tuple[str, set[str]]:
        _, path, args, kwargs = self.value.deconstruct()
        return self.serialize_deconstructed(path, args, kwargs)
225
226
class OperationSerializer(BaseSerializer):
    """Serialize a nested migration ``Operation`` using ``OperationWriter``."""

    def serialize(self) -> tuple[str, set[str]]:
        # Imported at call time; presumably to avoid a circular import with
        # the writer module (verify before hoisting).
        from plain.models.migrations.writer import OperationWriter

        writer = OperationWriter(self.value, indentation=0)
        source, imports = writer.serialize()
        # The enclosing OperationWriter._write() supplies the trailing comma
        # for nested operations, so drop any emitted here.
        return source.rstrip(","), imports
234
235
class PathLikeSerializer(BaseSerializer):
    """Serialize ``os.PathLike`` objects as the repr of ``os.fspath(value)``."""

    def serialize(self) -> tuple[str, set[str]]:
        fs_path = os.fspath(self.value)
        return repr(fs_path), set()
239
240
class PathSerializer(BaseSerializer):
    """Serialize ``pathlib`` paths, downgrading concrete paths to pure ones.

    A concrete path such as ``PosixPath(...)`` is emitted as
    ``pathlib.PurePosixPath(...)``. This avoids problems when a migration
    generated on one platform is loaded on another.
    """

    def serialize(self) -> tuple[str, set[str]]:
        rendered = repr(self.value)
        if isinstance(self.value, pathlib.Path):
            rendered = "Pure" + rendered
        return f"pathlib.{rendered}", {"import pathlib"}
247
248
class RegexSerializer(BaseSerializer):
    """Serialize compiled regexes as ``re.compile(pattern[, flags])``.

    The flags argument is emitted only when non-default flags remain.
    """

    def serialize(self) -> tuple[str, set[str]]:
        regex_pattern, pattern_imports = serializer_factory(
            self.value.pattern
        ).serialize()
        # Clear the implicit default flag bits (re.UNICODE, which re.compile
        # adds to every str pattern) so only non-default flags are emitted.
        # Use a mask rather than XOR: XOR would *add* re.UNICODE when it is
        # absent (bytes patterns, or str patterns compiled with re.ASCII),
        # and re.compile rejects the resulting call with ValueError.
        flags = self.value.flags & ~re.compile("").flags
        regex_flags, flag_imports = serializer_factory(flags).serialize()
        imports: set[str] = {"import re", *pattern_imports, *flag_imports}
        args = [regex_pattern]
        if flags:
            args.append(regex_flags)
        return "re.compile({})".format(", ".join(args)), imports
263
264
class SequenceSerializer(BaseSequenceSerializer):
    """Serialize a ``list`` as a list literal."""

    def _format(self) -> str:
        list_template = "[%s]"
        return list_template
268
269
class SetSerializer(BaseSequenceSerializer):
    """Serialize a ``set`` as a set literal (``set()`` when empty).

    Elements are sorted by ``repr`` before serialization. Set iteration order
    depends on element hashes, and ``str`` hashes are randomized per process,
    so emitting elements in iteration order could produce different migration
    source for the same value on different runs.
    """

    def __init__(self, value: Any) -> None:
        super().__init__(sorted(value, key=repr))

    def _format(self) -> str:
        # Serialize as a set literal except when value is empty because {}
        # is an empty dict.
        return "{%s}" if self.value else "set(%s)"
275
276
class SettingsReferenceSerializer(BaseSerializer):
    """Serialize a ``SettingsReference`` as the expression ``settings.<NAME>``.

    The setting's name is emitted rather than its current value.
    """

    def serialize(self) -> tuple[str, set[str]]:
        setting = self.value.setting_name
        needed = {"from plain.runtime import settings"}
        return f"settings.{setting}", needed
282
283
class TupleSerializer(BaseSequenceSerializer):
    """Serialize a ``tuple`` as a tuple literal."""

    def _format(self) -> str:
        # Only a single-element tuple takes a trailing comma. Every other
        # length, including the empty tuple (where "(,)" is a syntax error),
        # uses the plain form.
        if len(self.value) == 1:
            return "(%s,)"
        return "(%s)"
289
290
class TypeSerializer(BaseSerializer):
    """Serialize a class object as an importable reference.

    - ``Model`` and ``NoneType`` have fixed spellings.
    - Builtins are referenced by bare name.
    - Any other class becomes ``module.QualName`` plus ``import module``.
    - An object with no ``__module__`` yields an empty string.
    """

    def serialize(self) -> tuple[str, set[str]]:
        if self.value is Model:
            return "models.Model", {"from plain import models"}
        if self.value is types.NoneType:
            return "types.NoneType", {"import types"}
        if not hasattr(self.value, "__module__"):
            return "", set()
        module = self.value.__module__
        if module == builtins.__name__:
            return self.value.__name__, set()
        return f"{module}.{self.value.__qualname__}", {f"import {module}"}
307
308
class UUIDSerializer(BaseSerializer):
    """Serialize ``uuid.UUID`` as ``uuid.UUID('...')``."""

    def serialize(self) -> tuple[str, set[str]]:
        expression = "uuid." + repr(self.value)
        return expression, {"import uuid"}
312
313
class Serializer:
    """Ordered registry mapping value types to their serializer classes.

    serializer_factory() walks ``_registry`` in insertion order and uses the
    first entry whose key matches via ``isinstance``, so entry order matters.
    For example, ``Choices`` precedes ``enum.Enum``, presumably because
    Choices members are also Enum instances. Likewise, the concrete
    containers precede the catch-all ``collections.abc.Iterable``.
    """

    _registry: dict[Any, type[BaseSerializer]] = {
        # Some of these are order-dependent.
        frozenset: FrozensetSerializer,
        list: SequenceSerializer,
        set: SetSerializer,
        tuple: TupleSerializer,
        dict: DictionarySerializer,
        Choices: ChoicesSerializer,
        enum.Enum: EnumSerializer,
        datetime.datetime: DatetimeDatetimeSerializer,
        (datetime.date, datetime.timedelta, datetime.time): DateTimeSerializer,
        SettingsReference: SettingsReferenceSerializer,
        float: FloatSerializer,
        (bool, int, types.NoneType, bytes, str, range): BaseSimpleSerializer,
        decimal.Decimal: DecimalSerializer,
        (functools.partial, functools.partialmethod): FunctoolsPartialSerializer,
        (
            types.FunctionType,
            types.BuiltinFunctionType,
            types.MethodType,
        ): FunctionTypeSerializer,
        collections.abc.Iterable: IterableSerializer,
        (COMPILED_REGEX_TYPE, RegexObject): RegexSerializer,
        uuid.UUID: UUIDSerializer,
        pathlib.PurePath: PathSerializer,
        os.PathLike: PathLikeSerializer,
    }

    @classmethod
    def register(cls, type_: type[Any], serializer: type[BaseSerializer]) -> None:
        """Add or replace the serializer used for ``type_``.

        A new key is appended to the end of the lookup order. Re-registering
        an existing key replaces its serializer but keeps its position.

        Raises:
            ValueError: if ``serializer`` does not subclass BaseSerializer.
        """
        if issubclass(serializer, BaseSerializer):
            cls._registry[type_] = serializer
            return
        raise ValueError(
            f"'{serializer.__name__}' must inherit from 'BaseSerializer'."
        )
350
351
def serializer_factory(value: Any) -> BaseSerializer:
    """Return the serializer instance responsible for ``value``.

    Lazy values are resolved first: a ``Promise`` becomes ``str`` and a
    ``LazyObject`` becomes its wrapped object. The candidates are then tried
    in priority order:

    1. model fields,
    2. migration operations,
    3. classes,
    4. anything with a ``deconstruct`` attribute,
    5. the first ``Serializer._registry`` entry whose type matches.

    Raises:
        ValueError: if no serializer applies.
    """
    if isinstance(value, Promise):
        value = str(value)
    elif isinstance(value, LazyObject):
        # __reduce__() carries the unwrapped object as the first item of its
        # arguments tuple.
        value = value.__reduce__()[1][0]

    for kind, handler in (
        (Field, ModelFieldSerializer),
        (Operation, OperationSerializer),
        (type, TypeSerializer),
    ):
        if isinstance(value, kind):
            return handler(value)

    # Anything that knows how to deconstruct itself.
    if hasattr(value, "deconstruct"):
        return DeconstructableSerializer(value)

    for registered_type, handler_cls in Serializer._registry.items():
        if isinstance(value, registered_type):
            return handler_cls(value)

    raise ValueError(
        f"Cannot serialize: {value!r}\nThere are some values Plain cannot serialize into "
        "migration files."
    )