1from __future__ import annotations
2
3import builtins
4import collections.abc
5import datetime
6import decimal
7import enum
8import functools
9import math
10import os
11import pathlib
12import re
13import types
14import uuid
15from abc import ABC, abstractmethod
16from typing import Any
17
18from plain.models.base import Model
19from plain.models.enums import Choices
20from plain.models.fields import Field
21from plain.models.migrations.operations.base import Operation
22from plain.models.migrations.utils import COMPILED_REGEX_TYPE, RegexObject
23from plain.runtime import SettingsReference
24from plain.utils.functional import LazyObject, Promise
25
26
class BaseSerializer(ABC):
    """Abstract base for migration value serializers.

    A serializer wraps one value and renders it as a pair of
    ``(python_source, import_lines)`` via :meth:`serialize`.
    """

    def __init__(self, value: Any) -> None:
        # The raw value to render; subclasses read it in serialize().
        self.value = value

    @abstractmethod
    def serialize(self) -> tuple[str, set[str]]:
        """Return ``(source, imports)`` describing ``self.value``."""
33
34
class BaseSequenceSerializer(BaseSerializer, ABC):
    """Serialize a container element by element into a literal.

    Subclasses provide ``_format()``: a %-style template containing a single
    ``%s`` slot that receives the comma-separated element sources.
    """

    @abstractmethod
    def _format(self) -> str:
        """Return the %-template (e.g. ``"[%s]"``) that wraps the elements."""

    def serialize(self) -> tuple[str, set[str]]:
        needed_imports: set[str] = set()
        rendered: list[str] = []
        for element in self.value:
            source, element_imports = serializer_factory(element).serialize()
            needed_imports.update(element_imports)
            rendered.append(source)
        template = self._format()
        return template % ", ".join(rendered), needed_imports
48
49
class BaseSimpleSerializer(BaseSerializer):
    """Serialize a value by its ``repr()``, requiring no imports."""

    def serialize(self) -> tuple[str, set[str]]:
        source = repr(self.value)
        return source, set()
53
54
class ChoicesSerializer(BaseSerializer):
    """Serialize a ``Choices`` member as its underlying ``.value``."""

    def serialize(self) -> tuple[str, set[str]]:
        # Delegate to whichever serializer handles the member's raw value.
        raw_value = self.value.value
        return serializer_factory(raw_value).serialize()
58
59
class DateTimeSerializer(BaseSerializer):
    """Serialize ``datetime`` module values other than ``datetime.datetime``.

    The rendered source is the value's ``repr()``, which needs only
    ``import datetime``.
    """

    def serialize(self) -> tuple[str, set[str]]:
        required = {"import datetime"}
        return repr(self.value), required
65
66
class DatetimeDatetimeSerializer(BaseSerializer):
    """Serialize ``datetime.datetime`` values.

    Aware datetimes whose tzinfo is not UTC are converted to UTC before
    rendering. Presumably this keeps the ``repr()`` free of tzinfo classes
    that ``import datetime`` would not cover. Naive datetimes are rendered
    unchanged.
    """

    def serialize(self) -> tuple[str, set[str]]:
        value = self.value
        # Convert a local rather than rebinding self.value, so calling
        # serialize() leaves the serializer instance untouched.
        if value.tzinfo is not None and value.tzinfo != datetime.UTC:
            value = value.astimezone(datetime.UTC)
        return repr(value), {"import datetime"}
75
76
class DecimalSerializer(BaseSerializer):
    """Serialize ``decimal.Decimal`` via its ``repr()``."""

    def serialize(self) -> tuple[str, set[str]]:
        # repr() yields a bare ``Decimal(...)`` call, hence the from-import.
        needed = {"from decimal import Decimal"}
        return repr(self.value), needed
80
81
class DeconstructableSerializer(BaseSerializer):
    """Serialize any object that implements ``deconstruct()``.

    ``deconstruct()`` supplies ``(path, args, kwargs)``. The value is rendered
    as the call ``<path>(*args, **kwargs)``: each argument is serialized
    recursively, and keyword arguments are emitted in sorted order.
    """

    @staticmethod
    def serialize_deconstructed(
        path: str, args: tuple[Any, ...], kwargs: dict[str, Any]
    ) -> tuple[str, set[str]]:
        """Render ``path(*args, **kwargs)`` and collect every import it needs."""
        callee, imports = DeconstructableSerializer._serialize_path(path)
        pieces: list[str] = []

        for positional in args:
            rendered, needed = serializer_factory(positional).serialize()
            pieces.append(rendered)
            imports.update(needed)

        # Keywords are emitted in sorted order so the output does not depend on
        # how the kwargs dict happened to be built.
        for keyword, argument in sorted(kwargs.items()):
            rendered, needed = serializer_factory(argument).serialize()
            imports.update(needed)
            pieces.append(f"{keyword}={rendered}")

        joined = ", ".join(pieces)
        return f"{callee}({joined})", imports

    @staticmethod
    def _serialize_path(path: str) -> tuple[str, set[str]]:
        """Split a dotted ``path`` into the expression to emit and its imports.

        Paths under ``plain.models`` are shortened to ``models.<Name>`` and use
        ``from plain import models``. Any other path is emitted in full after
        ``import <module>``.
        """
        module, attribute = path.rsplit(".", 1)
        if module != "plain.models":
            return path, {f"import {module}"}
        return f"models.{attribute}", {"from plain import models"}

    def serialize(self) -> tuple[str, set[str]]:
        deconstructed = self.value.deconstruct()
        return self.serialize_deconstructed(*deconstructed)
112
113
class DictionarySerializer(BaseSerializer):
    """Serialize a dict as a ``{key: value, ...}`` literal.

    Items are emitted in sorted key order so the rendered migration is stable.
    Some keys cannot be ordered against each other, for example mixed
    ``int``/``str`` keys or ``enum.Enum`` members, which define no ``<``. For
    those, the dict's own insertion order is used instead of failing with
    ``TypeError``.
    """

    def serialize(self) -> tuple[str, set[str]]:
        imports: set[str] = set()
        strings = []
        try:
            items = sorted(self.value.items())
        except TypeError:
            # Unorderable keys: fall back to insertion order, which is
            # deterministic for a given dict.
            items = list(self.value.items())
        for k, v in items:
            k_string, k_imports = serializer_factory(k).serialize()
            v_string, v_imports = serializer_factory(v).serialize()
            imports.update(k_imports)
            imports.update(v_imports)
            strings.append((k_string, v_string))
        return "{{{}}}".format(", ".join(f"{k}: {v}" for k, v in strings)), imports
125
126
class EnumSerializer(BaseSerializer):
    """Serialize an enum member as ``module.Class['NAME']``.

    An ``enum.Flag`` value is decomposed into its member flags, and those are
    joined with `` | ``.

    A Flag value that decomposes into no members (e.g. ``Perm(0)``) is rendered
    as ``module.Class(<value>)``. Joining zero members would otherwise produce
    an empty string, which is not a valid expression.
    """

    def serialize(self) -> tuple[str, set[str]]:
        enum_class = self.value.__class__
        module = enum_class.__module__
        class_path = f"{module}.{enum_class.__qualname__}"
        imports = {f"import {module}"}
        if issubclass(enum_class, enum.Flag):
            # Iterating a Flag value yields its member flags.
            members = list(self.value)
            if not members:
                return f"{class_path}({self.value.value!r})", imports
        else:
            members = [self.value]
        return (
            " | ".join(f"{class_path}[{item.name!r}]" for item in members),
            imports,
        )
144
145
class FloatSerializer(BaseSimpleSerializer):
    """Serialize floats, spelling out non-finite values explicitly.

    NaN and the infinities are emitted as ``float("<text>")`` calls. Every
    finite float falls back to the plain ``repr()`` path.
    """

    def serialize(self) -> tuple[str, set[str]]:
        if math.isfinite(self.value):
            return super().serialize()
        return f'float("{self.value}")', set()
151
152
class FrozensetSerializer(BaseSequenceSerializer):
    """Serialize a frozenset as ``frozenset([...])``.

    Elements are sorted by ``repr()`` before rendering. A frozenset's
    iteration order follows element hashes, and for ``str``/``bytes`` those
    change between interpreter runs (hash randomization). Rendering in raw
    iteration order could therefore produce a different migration file on
    every run.
    """

    def __init__(self, value: Any) -> None:
        super().__init__(sorted(value, key=repr))

    def _format(self) -> str:
        return "frozenset([%s])"
156
157
class FunctionTypeSerializer(BaseSerializer):
    """Serialize a function or method as a dotted, importable reference.

    Raises:
        ValueError: for lambdas, for functions with no ``__module__``, and for
            functions whose ``__qualname__`` contains ``<`` (e.g.
            ``<locals>``), since such names cannot be imported.
    """

    def serialize(self) -> tuple[str, set[str]]:
        bound_to = getattr(self.value, "__self__", None)
        if bound_to and isinstance(bound_to, type):
            # Method bound to a class (e.g. a classmethod): reference it through
            # the class. __qualname__ (not __name__) keeps the full path for
            # classes nested in other classes (Outer.Inner.method).
            module = bound_to.__module__
            return f"{module}.{bound_to.__qualname__}.{self.value.__name__}", {
                f"import {module}"
            }
        # Further error checking
        if self.value.__name__ == "<lambda>":
            raise ValueError("Cannot serialize function: lambda")
        if self.value.__module__ is None:
            raise ValueError(f"Cannot serialize function {self.value!r}: No module")

        module_name = self.value.__module__

        if "<" not in self.value.__qualname__:  # Qualname can include <locals>
            return f"{module_name}.{self.value.__qualname__}", {
                f"import {module_name}"
            }

        raise ValueError(
            f"Could not find function {self.value.__name__} in {module_name}.\n"
        )
184
185
class FunctoolsPartialSerializer(BaseSerializer):
    """Serialize ``functools.partial`` / ``functools.partialmethod`` objects.

    The result is ``functools.<ClassName>(func, *args, **keywords)``, with the
    wrapped callable, positional args and keywords each serialized
    recursively.
    """

    def serialize(self) -> tuple[str, set[str]]:
        wrapped = self.value
        func_src, func_imports = serializer_factory(wrapped.func).serialize()
        args_src, args_imports = serializer_factory(wrapped.args).serialize()
        kw_src, kw_imports = serializer_factory(wrapped.keywords).serialize()

        imports: set[str] = {"import functools"}.union(
            func_imports, args_imports, kw_imports
        )
        kind = wrapped.__class__.__name__
        call = f"functools.{kind}({func_src}, *{args_src}, **{kw_src})"
        return call, imports
205
206
class IterableSerializer(BaseSerializer):
    """Serialize an arbitrary iterable as a tuple literal."""

    def serialize(self) -> tuple[str, set[str]]:
        imports: set[str] = set()
        sources: list[str] = []
        for element in self.value:
            src, element_imports = serializer_factory(element).serialize()
            imports.update(element_imports)
            sources.append(src)
        body = ", ".join(sources)
        # "(x)" is just x, so a single element needs a trailing comma. An empty
        # iterable must stay "()", because "(,)" is a syntax error.
        if len(sources) == 1:
            body += ","
        return f"({body})", imports
219
220
class ModelFieldSerializer(DeconstructableSerializer):
    """Serialize a model ``Field`` from its four-item ``deconstruct()`` result.

    The leading item (the attribute name) is not part of the constructor call
    and is discarded. Only ``(path, args, kwargs)`` are rendered.
    """

    def serialize(self) -> tuple[str, set[str]]:
        _attname, field_path, field_args, field_kwargs = self.value.deconstruct()
        return self.serialize_deconstructed(field_path, field_args, field_kwargs)
225
226
class OperationSerializer(BaseSerializer):
    """Serialize a migration ``Operation`` nested inside another value."""

    def serialize(self) -> tuple[str, set[str]]:
        # Imported at call time; presumably to avoid an import cycle with the
        # writer module (confirm).
        from plain.models.migrations.writer import OperationWriter

        writer = OperationWriter(self.value, indentation=0)
        source, imports = writer.serialize()
        # Nested operation: the enclosing OperationWriter._write() adds the
        # trailing comma, so strip any trailing comma here.
        return source.rstrip(","), imports
234
235
class PathLikeSerializer(BaseSerializer):
    """Serialize an ``os.PathLike`` as the literal of its ``os.fspath()``."""

    def serialize(self) -> tuple[str, set[str]]:
        fs_path = os.fspath(self.value)
        return repr(fs_path), set()
239
240
class PathSerializer(BaseSerializer):
    """Serialize ``pathlib`` paths as ``pathlib.Pure*Path(...)`` expressions."""

    def serialize(self) -> tuple[str, set[str]]:
        # Concrete paths are emitted as their pure counterparts to avoid issues
        # when a migration generated on one platform is used on another.
        path_repr = repr(self.value)
        if isinstance(self.value, pathlib.Path):
            path_repr = "Pure" + path_repr
        return "pathlib." + path_repr, {"import pathlib"}
247
248
class RegexSerializer(BaseSerializer):
    """Serialize a compiled regex as ``re.compile(pattern[, flags])``.

    Flags the ``re`` module adds implicitly for the pattern's type are left out,
    so only explicit flags appear in the output.
    """

    def serialize(self) -> tuple[str, set[str]]:
        pattern = self.value.pattern
        regex_pattern, pattern_imports = serializer_factory(pattern).serialize()
        # Turn off default implicit flags (e.g. re.U), because regexes with the
        # same implicit and explicit flags aren't equal. The baseline comes from
        # an empty pattern of the SAME type (pattern[:0]): str patterns get
        # re.UNICODE implicitly, bytes patterns get nothing.
        # Bits are cleared with `& ~` rather than toggled with `^`. XOR would
        # add re.UNICODE back for bytes patterns and for re.ASCII str patterns,
        # and re.compile() rejects both of those combinations.
        implicit_flags = re.compile(pattern[:0]).flags
        flags = self.value.flags & ~implicit_flags
        regex_flags, flag_imports = serializer_factory(flags).serialize()
        imports: set[str] = {"import re", *pattern_imports, *flag_imports}
        args = [regex_pattern]
        if flags:
            args.append(regex_flags)
        return "re.compile({})".format(", ".join(args)), imports
263
264
class SequenceSerializer(BaseSequenceSerializer):
    """Serialize a list as a ``[...]`` literal, keeping element order."""

    def _format(self) -> str:
        list_template = "[%s]"
        return list_template
268
269
class SetSerializer(BaseSequenceSerializer):
    """Serialize a set as a ``{...}`` literal, or ``set()`` when empty.

    Elements are sorted by ``repr()`` before rendering. Set iteration order
    follows element hashes, and for ``str``/``bytes`` those change between
    interpreter runs (hash randomization). Rendering in raw iteration order
    could therefore produce a different migration file on every run.
    """

    def __init__(self, value: Any) -> None:
        super().__init__(sorted(value, key=repr))

    def _format(self) -> str:
        # Serialize as a set literal except when value is empty, because {}
        # is an empty dict.
        return "{%s}" if self.value else "set(%s)"
275
276
class SettingsReferenceSerializer(BaseSerializer):
    """Serialize a ``SettingsReference`` as a ``settings.<NAME>`` expression."""

    def serialize(self) -> tuple[str, set[str]]:
        expression = f"settings.{self.value.setting_name}"
        return expression, {"from plain.runtime import settings"}
282
283
class TupleSerializer(BaseSequenceSerializer):
    """Serialize a tuple as a ``(...)`` literal."""

    def _format(self) -> str:
        # "(x)" is just x, so a one-element tuple needs a trailing comma. An
        # empty tuple must stay "()", because "(,)" is a syntax error.
        if len(self.value) == 1:
            return "(%s,)"
        return "(%s)"
289
290
class TypeSerializer(BaseSerializer):
    """Serialize a class object as an importable reference to it."""

    def serialize(self) -> tuple[str, set[str]]:
        # Types with a fixed spelling, checked by identity.
        if self.value is Model:
            return "models.Model", {"from plain import models"}
        if self.value is types.NoneType:
            return "types.NoneType", {"import types"}

        if not hasattr(self.value, "__module__"):
            return "", set()

        module = self.value.__module__
        if module == builtins.__name__:
            # Builtins (int, str, ...) need neither a prefix nor an import.
            return self.value.__name__, set()
        return f"{module}.{self.value.__qualname__}", {f"import {module}"}
307
308
class UUIDSerializer(BaseSerializer):
    """Serialize a ``uuid.UUID`` as ``uuid.`` followed by its ``repr()``."""

    def serialize(self) -> tuple[str, set[str]]:
        return f"uuid.{self.value!r}", {"import uuid"}
312
313
class Serializer:
    """Registry mapping value types to the serializer class that handles them.

    ``serializer_factory`` scans ``_registry`` in insertion order and uses the
    first entry whose key satisfies ``isinstance(value, key)``. Entries for
    narrower types must therefore precede broader ones. For example:

    - ``datetime.datetime`` (a ``datetime.date`` subclass) comes before the
      ``datetime.date`` tuple.
    - ``str``/``bytes``/``range`` come before ``collections.abc.Iterable``.
    - ``pathlib.PurePath`` comes before ``os.PathLike``.
    """

    # Insertion-ordered; the first isinstance() match wins.
    _registry = {
        # Some of these are order-dependent.
        frozenset: FrozensetSerializer,
        list: SequenceSerializer,
        set: SetSerializer,
        tuple: TupleSerializer,
        dict: DictionarySerializer,
        Choices: ChoicesSerializer,
        enum.Enum: EnumSerializer,
        datetime.datetime: DatetimeDatetimeSerializer,
        (datetime.date, datetime.timedelta, datetime.time): DateTimeSerializer,
        SettingsReference: SettingsReferenceSerializer,
        float: FloatSerializer,
        (bool, int, types.NoneType, bytes, str, range): BaseSimpleSerializer,
        decimal.Decimal: DecimalSerializer,
        (functools.partial, functools.partialmethod): FunctoolsPartialSerializer,
        (
            types.FunctionType,
            types.BuiltinFunctionType,
            types.MethodType,
        ): FunctionTypeSerializer,
        # Broad catch-all for iterables; must stay after the specific
        # container and str/bytes entries above.
        collections.abc.Iterable: IterableSerializer,
        (COMPILED_REGEX_TYPE, RegexObject): RegexSerializer,
        uuid.UUID: UUIDSerializer,
        pathlib.PurePath: PathSerializer,
        os.PathLike: PathLikeSerializer,
    }

    @classmethod
    def register(cls, type_: type[Any], serializer: type[BaseSerializer]) -> None:
        """Add or replace the serializer used for ``type_``.

        Re-registering an existing key keeps its original position. A new key
        is appended at the end of ``_registry``.

        NOTE(review): because the scan stops at the first match, a newly
        registered type that an earlier entry already matches (e.g. a custom
        iterable, matched by ``collections.abc.Iterable``) will never be
        reached.

        Raises:
            ValueError: if ``serializer`` is a class that does not inherit from
                BaseSerializer. (A non-class argument makes ``issubclass``
                raise TypeError instead.)
        """
        if not issubclass(serializer, BaseSerializer):
            raise ValueError(
                f"'{serializer.__name__}' must inherit from 'BaseSerializer'."
            )
        cls._registry[type_] = serializer
350
351
def serializer_factory(value: Any) -> BaseSerializer:
    """Return a serializer instance suited to ``value``.

    Lazy wrappers (``Promise``, ``LazyObject``) are resolved first. After that,
    the following are checked in order, ahead of the type registry:

    1. model ``Field`` instances;
    2. migration ``Operation`` instances;
    3. classes;
    4. any object with a ``deconstruct`` attribute.

    If none of these apply, the first matching ``Serializer._registry`` entry
    is used.

    Raises:
        ValueError: if nothing can serialize ``value``.
    """
    # Resolve lazy wrappers into the concrete values they stand for.
    if isinstance(value, Promise):
        value = str(value)
    elif isinstance(value, LazyObject):
        # The unwrapped value is the first item of __reduce__()'s args tuple.
        value = value.__reduce__()[1][0]

    # Checked before the registry, in this exact order.
    priority_checks = (
        (Field, ModelFieldSerializer),
        (Operation, OperationSerializer),
        (type, TypeSerializer),
    )
    for kind, chosen in priority_checks:
        if isinstance(value, kind):
            return chosen(value)

    # Anything that knows how to deconstruct itself.
    if hasattr(value, "deconstruct"):
        return DeconstructableSerializer(value)

    matched = next(
        (
            serializer_cls
            for registered_type, serializer_cls in Serializer._registry.items()
            if isinstance(value, registered_type)
        ),
        None,
    )
    if matched is not None:
        return matched(value)

    raise ValueError(
        f"Cannot serialize: {value!r}\nThere are some values Plain cannot serialize into "
        "migration files."
    )