1from __future__ import annotations
  2
  3import builtins
  4import collections.abc
  5import datetime
  6import decimal
  7import enum
  8import functools
  9import math
 10import os
 11import pathlib
 12import re
 13import types
 14import uuid
 15from typing import Any
 16
 17from plain.postgres.base import Model
 18from plain.postgres.enums import Choices
 19from plain.postgres.fields import Field
 20from plain.postgres.migrations.operations.base import Operation
 21from plain.postgres.migrations.utils import COMPILED_REGEX_TYPE, RegexObject
 22from plain.runtime import SettingsReference
 23from plain.utils.functional import LazyObject, Promise
 24
 25
 26class BaseSerializer:
 27    def __init__(self, value: Any) -> None:
 28        self.value = value
 29
 30    def serialize(self) -> tuple[str, set[str]]:
 31        raise NotImplementedError(
 32            "subclasses of BaseSerializer must provide a serialize() method"
 33        )
 34
 35
 36class BaseSequenceSerializer(BaseSerializer):
 37    def _format(self) -> str:
 38        raise NotImplementedError(
 39            "subclasses of BaseSequenceSerializer must provide a _format() method"
 40        )
 41
 42    def serialize(self) -> tuple[str, set[str]]:
 43        imports: set[str] = set()
 44        strings = []
 45        for item in self.value:
 46            item_string, item_imports = serializer_factory(item).serialize()
 47            imports.update(item_imports)
 48            strings.append(item_string)
 49        value = self._format()
 50        return value % (", ".join(strings)), imports
 51
 52
 53class BaseSimpleSerializer(BaseSerializer):
 54    def serialize(self) -> tuple[str, set[str]]:
 55        return repr(self.value), set()
 56
 57
 58class ChoicesSerializer(BaseSerializer):
 59    def serialize(self) -> tuple[str, set[str]]:
 60        return serializer_factory(self.value.value).serialize()
 61
 62
 63class DateTimeSerializer(BaseSerializer):
 64    """For datetime.*, except datetime.datetime."""
 65
 66    def serialize(self) -> tuple[str, set[str]]:
 67        return repr(self.value), {"import datetime"}
 68
 69
 70class DatetimeDatetimeSerializer(BaseSerializer):
 71    """For datetime.datetime."""
 72
 73    def serialize(self) -> tuple[str, set[str]]:
 74        if self.value.tzinfo is not None and self.value.tzinfo != datetime.UTC:
 75            self.value = self.value.astimezone(datetime.UTC)
 76        imports = ["import datetime"]
 77        return repr(self.value), set(imports)
 78
 79
 80class DecimalSerializer(BaseSerializer):
 81    def serialize(self) -> tuple[str, set[str]]:
 82        return repr(self.value), {"from decimal import Decimal"}
 83
 84
 85class DeconstructableSerializer(BaseSerializer):
 86    @staticmethod
 87    def serialize_deconstructed(
 88        path: str, args: tuple[Any, ...], kwargs: dict[str, Any]
 89    ) -> tuple[str, set[str]]:
 90        name, imports = DeconstructableSerializer._serialize_path(path)
 91        strings = []
 92        for arg in args:
 93            arg_string, arg_imports = serializer_factory(arg).serialize()
 94            strings.append(arg_string)
 95            imports.update(arg_imports)
 96        for kw, arg in sorted(kwargs.items()):
 97            arg_string, arg_imports = serializer_factory(arg).serialize()
 98            imports.update(arg_imports)
 99            strings.append(f"{kw}={arg_string}")
100        return "{}({})".format(name, ", ".join(strings)), imports
101
102    @staticmethod
103    def _serialize_path(path: str) -> tuple[str, set[str]]:
104        module, name = path.rsplit(".", 1)
105        if module == "plain.postgres":
106            imports: set[str] = {"from plain import postgres"}
107            name = f"postgres.{name}"
108        else:
109            imports = {f"import {module}"}
110            name = path
111        return name, imports
112
113    def serialize(self) -> tuple[str, set[str]]:
114        return self.serialize_deconstructed(*self.value.deconstruct())
115
116
class DictionarySerializer(BaseSerializer):
    """Serialize a mapping as a dict literal with a deterministic key order."""

    def serialize(self) -> tuple[str, set[str]]:
        """Return the dict literal for ``self.value`` and the merged imports.

        Items are emitted in natural sorted order. When the keys cannot be
        ordered against each other, items are ordered by the ``repr()`` of
        their key instead of failing.
        """
        try:
            items = sorted(self.value.items())
        except TypeError:
            # Keys such as None, plain Enum members or mixed types do not
            # support "<"; sort by repr so such dicts still serialize.
            items = sorted(self.value.items(), key=lambda item: repr(item[0]))

        imports: set[str] = set()
        pairs: list[str] = []
        for key, val in items:
            k_string, k_imports = serializer_factory(key).serialize()
            v_string, v_imports = serializer_factory(val).serialize()
            imports.update(k_imports)
            imports.update(v_imports)
            pairs.append(f"{k_string}: {v_string}")
        return "{{{}}}".format(", ".join(pairs)), imports
128
129
class EnumSerializer(BaseSerializer):
    """Serialize enum members as ``module.Class['NAME']`` lookups."""

    def serialize(self) -> tuple[str, set[str]]:
        """Return the source for the enum value and the import of its module.

        A ``Flag`` value is written as its contained members joined with
        ``|``. A ``Flag`` value with no contained members (the zero value) is
        written as ``module.Class(<raw value>)``, because joining zero members
        would produce an empty string.
        """
        enum_class = self.value.__class__
        module = enum_class.__module__
        class_path = f"{module}.{enum_class.__qualname__}"
        imports = {f"import {module}"}

        if issubclass(enum_class, enum.Flag):
            members = list(self.value)
            if not members:
                # Iterating a zero-valued flag yields nothing; reference it
                # by value so the generated code is still a valid expression.
                return f"{class_path}({self.value.value!r})", imports
        else:
            members = (self.value,)

        source = " | ".join(f"{class_path}[{item.name!r}]" for item in members)
        return source, imports
147
148
class FloatSerializer(BaseSimpleSerializer):
    """Serialize floats, spelling NaN and infinities as ``float("...")``."""

    def serialize(self) -> tuple[str, set[str]]:
        number = self.value
        if math.isfinite(number):
            return super().serialize()
        # repr() of nan/inf is not a valid literal, so go through float().
        return f'float("{number}")', set()
154
155
class FrozensetSerializer(BaseSequenceSerializer):
    """Serialize a frozenset as ``frozenset([...])`` with stable ordering."""

    def __init__(self, value: Any) -> None:
        # Set iteration order is not stable across processes (string hashes
        # are randomized), so order the elements by repr to make the
        # generated source deterministic.
        super().__init__(sorted(value, key=repr))

    def _format(self) -> str:
        """Return the template wrapping the joined elements."""
        return "frozenset([%s])"
159
160
class FunctionTypeSerializer(BaseSerializer):
    """Serialize functions, builtins and bound methods as dotted paths."""

    def serialize(self) -> tuple[str, set[str]]:
        """Return the dotted path of ``self.value`` and the import it needs.

        Raises:
            ValueError: for lambdas, for callables whose ``__module__`` is
                ``None``, and for functions whose qualified name contains
                ``<`` (for example ``<locals>``).
        """
        func = self.value
        owner = getattr(func, "__self__", None)
        if owner and isinstance(owner, type):
            # Method bound to a class. Use __qualname__ so a nested class
            # keeps its enclosing class(es) in the emitted path.
            module = owner.__module__
            return f"{module}.{owner.__qualname__}.{func.__name__}", {
                f"import {module}"
            }
        # Further error checking
        if func.__name__ == "<lambda>":
            raise ValueError("Cannot serialize function: lambda")
        if func.__module__ is None:
            raise ValueError(f"Cannot serialize function {func!r}: No module")

        module_name = func.__module__

        if "<" not in func.__qualname__:  # Qualname can include <locals>
            return f"{module_name}.{func.__qualname__}", {f"import {module_name}"}

        raise ValueError(
            f"Could not find function {func.__name__} in {module_name}.\n"
        )
187
188
class FunctoolsPartialSerializer(BaseSerializer):
    """Serialize ``functools`` partial objects as a reconstructing call."""

    def serialize(self) -> tuple[str, set[str]]:
        """Return ``functools.<Class>(func, *args, **keywords)`` and imports."""
        partial = self.value
        # Each component is serialized on its own, in func/args/keywords order.
        func_src, func_imports = serializer_factory(partial.func).serialize()
        args_src, args_imports = serializer_factory(partial.args).serialize()
        kwargs_src, kwargs_imports = serializer_factory(partial.keywords).serialize()

        imports: set[str] = {"import functools"}
        imports.update(func_imports, args_imports, kwargs_imports)

        class_name = partial.__class__.__name__
        call = f"functools.{class_name}({func_src}, *{args_src}, **{kwargs_src})"
        return call, imports
208
209
class IterableSerializer(BaseSerializer):
    """Serialize an arbitrary iterable as a tuple literal."""

    def serialize(self) -> tuple[str, set[str]]:
        rendered: list[str] = []
        imports: set[str] = set()
        for element in self.value:
            text, needed = serializer_factory(element).serialize()
            rendered.append(text)
            imports.update(needed)
        joined = ", ".join(rendered)
        if len(rendered) == 1:
            # A one-element tuple needs the trailing comma.
            return f"({joined},)", imports
        # Empty iterables must come out as "()", since "(,)" is invalid.
        return f"({joined})", imports
222
223
class ModelFieldSerializer(DeconstructableSerializer):
    """Serialize a field from the four-item result of its ``deconstruct()``."""

    def serialize(self) -> tuple[str, set[str]]:
        # The first item of the deconstructed tuple is not used for output.
        _attr_name, path, args, kwargs = self.value.deconstruct()
        return self.serialize_deconstructed(path, args, kwargs)
228
229
class OperationSerializer(BaseSerializer):
    """Serialize a migration operation nested inside another value."""

    def serialize(self) -> tuple[str, set[str]]:
        # Imported here rather than at module level.
        from plain.postgres.migrations.writer import OperationWriter

        writer = OperationWriter(self.value, indentation=0)
        source, imports = writer.serialize()
        # Nested operation, trailing comma is handled in upper OperationWriter._write()
        without_trailing_comma = source.rstrip(",")
        return without_trailing_comma, imports
237
238
class PathLikeSerializer(BaseSerializer):
    """Serialize an ``os.PathLike`` as the ``repr()`` of its filesystem path."""

    def serialize(self) -> tuple[str, set[str]]:
        fs_path = os.fspath(self.value)
        return repr(fs_path), set()
242
243
class PathSerializer(BaseSerializer):
    """Serialize ``pathlib`` paths, always as their pure-path flavour."""

    def serialize(self) -> tuple[str, set[str]]:
        # Convert concrete paths to pure paths to avoid issues with migrations
        # generated on one platform being used on a different platform.
        if isinstance(self.value, pathlib.Path):
            prefix = "Pure"
        else:
            prefix = ""
        source = f"pathlib.{prefix}{self.value!r}"
        return source, {"import pathlib"}
250
251
class RegexSerializer(BaseSerializer):
    """Serialize a compiled regex as a ``re.compile(...)`` call."""

    def serialize(self) -> tuple[str, set[str]]:
        pattern_src, pattern_imports = serializer_factory(
            self.value.pattern
        ).serialize()
        # Turn off default implicit flags (e.g. re.U) because regexes with the
        # same implicit and explicit flags aren't equal.
        implicit_flags = re.compile("").flags
        flags = self.value.flags ^ implicit_flags
        flags_src, flag_imports = serializer_factory(flags).serialize()

        imports: set[str] = {"import re"}
        imports.update(pattern_imports)
        imports.update(flag_imports)

        # The flags argument is only written when explicit flags remain.
        call_args = [pattern_src, flags_src] if flags else [pattern_src]
        return f"re.compile({', '.join(call_args)})", imports
266
267
class SequenceSerializer(BaseSequenceSerializer):
    """Serialize a sequence as a list literal."""

    def _format(self) -> str:
        """Return the list-literal template for the joined elements."""
        template = "[%s]"
        return template
271
272
class SetSerializer(BaseSequenceSerializer):
    """Serialize a set as a set literal with stable element ordering."""

    def __init__(self, value: Any) -> None:
        # Set iteration order is not stable across processes (string hashes
        # are randomized), so order the elements by repr to make the
        # generated source deterministic.
        super().__init__(sorted(value, key=repr))

    def _format(self) -> str:
        """Return the template wrapping the joined elements."""
        # Serialize as a set literal except when value is empty because {}
        # is an empty dict.
        return "{%s}" if self.value else "set(%s)"
278
279
class SettingsReferenceSerializer(BaseSerializer):
    """Serialize a settings reference as an attribute lookup on ``settings``."""

    def serialize(self) -> tuple[str, set[str]]:
        setting_name = self.value.setting_name
        required = {"from plain.runtime import settings"}
        return f"settings.{setting_name}", required
285
286
class TupleSerializer(BaseSequenceSerializer):
    """Serialize a tuple as a tuple literal."""

    def _format(self) -> str:
        """Return the tuple template, with a trailing comma for one element."""
        if len(self.value) == 1:
            return "(%s,)"
        # When len(value)==0, the empty tuple should be serialized as "()",
        # not "(,)" because (,) is invalid Python syntax.
        return "(%s)"
292
293
class TypeSerializer(BaseSerializer):
    """Serialize a class object as a reference to that class."""

    def serialize(self) -> tuple[str, set[str]]:
        """Return the name or dotted path of the class and its import.

        ``Model`` and ``types.NoneType`` have fixed spellings; builtins are
        written by bare name; anything else as ``module.QualName``.
        """
        target = self.value
        # Special cases are matched by identity, not by subclassing.
        if target is Model:
            return "postgres.Model", {"from plain import postgres"}
        if target is types.NoneType:
            return "types.NoneType", {"import types"}

        if not hasattr(target, "__module__"):
            return "", set()

        module = target.__module__
        if module == builtins.__name__:
            return target.__name__, set()
        return f"{module}.{target.__qualname__}", {f"import {module}"}
310
311
class UUIDSerializer(BaseSerializer):
    """Serialize a UUID as its ``repr()`` prefixed with ``uuid.``."""

    def serialize(self) -> tuple[str, set[str]]:
        source = f"uuid.{self.value!r}"
        return source, {"import uuid"}
315
316
class Serializer:
    """Registry mapping value types to the serializer class that handles them.

    ``serializer_factory`` walks ``_registry`` in insertion order and uses the
    first entry whose key matches via ``isinstance``, so earlier entries win.
    """

    _registry = {
        # Some of these are order-dependent.
        frozenset: FrozensetSerializer,
        list: SequenceSerializer,
        set: SetSerializer,
        tuple: TupleSerializer,
        dict: DictionarySerializer,
        Choices: ChoicesSerializer,
        enum.Enum: EnumSerializer,
        datetime.datetime: DatetimeDatetimeSerializer,
        (datetime.date, datetime.timedelta, datetime.time): DateTimeSerializer,
        SettingsReference: SettingsReferenceSerializer,
        float: FloatSerializer,
        (bool, int, types.NoneType, bytes, str, range): BaseSimpleSerializer,
        decimal.Decimal: DecimalSerializer,
        (functools.partial, functools.partialmethod): FunctoolsPartialSerializer,
        (
            types.FunctionType,
            types.BuiltinFunctionType,
            types.MethodType,
        ): FunctionTypeSerializer,
        collections.abc.Iterable: IterableSerializer,
        (COMPILED_REGEX_TYPE, RegexObject): RegexSerializer,
        uuid.UUID: UUIDSerializer,
        pathlib.PurePath: PathSerializer,
        os.PathLike: PathLikeSerializer,
    }

    @classmethod
    def register(cls, type_: type[Any], serializer: type[BaseSerializer]) -> None:
        """Associate ``type_`` with ``serializer`` in the registry.

        Raises:
            ValueError: if ``serializer`` is not a ``BaseSerializer`` subclass.
        """
        if issubclass(serializer, BaseSerializer):
            cls._registry[type_] = serializer  # type: ignore[assignment]
            return
        raise ValueError(
            f"'{serializer.__name__}' must inherit from 'BaseSerializer'."
        )
353
354
def serializer_factory(value: Any) -> BaseSerializer:
    """Return the serializer instance responsible for ``value``.

    Lazy wrappers are unwrapped first. Fields, operations and classes are
    dispatched next, then anything exposing ``deconstruct``, and finally the
    first matching entry of ``Serializer._registry``.

    Raises:
        ValueError: if no serializer handles ``value``.
    """
    if isinstance(value, Promise):
        value = str(value)
    elif isinstance(value, LazyObject):
        # The unwrapped value is returned as the first item of the arguments
        # tuple.
        reduce_args = value.__reduce__()[1]
        value = reduce_args[0]

    # Fixed dispatches that take priority over everything else, in order.
    priority_dispatch = (
        (Field, ModelFieldSerializer),
        (Operation, OperationSerializer),
        (type, TypeSerializer),
    )
    for kind, serializer_cls in priority_dispatch:
        if isinstance(value, kind):
            return serializer_cls(value)

    # Anything that knows how to deconstruct itself.
    if hasattr(value, "deconstruct"):
        return DeconstructableSerializer(value)

    registered = next(
        (
            serializer_cls
            for kind, serializer_cls in Serializer._registry.items()
            if isinstance(value, kind)
        ),
        None,
    )
    if registered is not None:
        return registered(value)

    raise ValueError(
        f"Cannot serialize: {value!r}\nThere are some values Plain cannot serialize into "
        "migration files."
    )