1from __future__ import annotations
2
3import builtins
4import collections.abc
5import datetime
6import decimal
7import enum
8import functools
9import math
10import os
11import pathlib
12import re
13import types
14import uuid
15from typing import Any
16
17from plain.models.base import Model
18from plain.models.enums import Choices
19from plain.models.fields import Field
20from plain.models.migrations.operations.base import Operation
21from plain.models.migrations.utils import COMPILED_REGEX_TYPE, RegexObject
22from plain.runtime import SettingsReference
23from plain.utils.functional import LazyObject, Promise
24
25
class BaseSerializer:
    """Common base for every migration-value serializer.

    A serializer wraps a single value; ``serialize()`` turns it into a pair
    ``(python_source, required_import_lines)``.
    """

    def __init__(self, value: Any) -> None:
        # The value this serializer will render as Python source.
        self.value = value

    def serialize(self) -> tuple[str, set[str]]:
        """Return ``(source, imports)`` for ``self.value``; subclasses override."""
        message = "subclasses of BaseSerializer must provide a serialize() method"
        raise NotImplementedError(message)
34
35
class BaseSequenceSerializer(BaseSerializer):
    """Serialize an iterable by serializing each item and joining them.

    Subclasses provide ``_format()``: a ``%``-template with exactly one
    ``%s`` placeholder that receives the comma-separated item sources.
    """

    def _format(self) -> str:
        """Return the ``%s`` template wrapping the joined item sources."""
        raise NotImplementedError(
            "subclasses of BaseSequenceSerializer must provide a _format() method"
        )

    def serialize(self) -> tuple[str, set[str]]:
        needed_imports: set[str] = set()
        rendered: list[str] = []
        for element in self.value:
            element_source, element_imports = serializer_factory(element).serialize()
            rendered.append(element_source)
            needed_imports.update(element_imports)
        template = self._format()
        joined = ", ".join(rendered)
        return template % joined, needed_imports
51
52
class BaseSimpleSerializer(BaseSerializer):
    """Serialize values whose ``repr()`` is already import-free Python source."""

    def serialize(self) -> tuple[str, set[str]]:
        source = repr(self.value)
        no_imports: set[str] = set()
        return source, no_imports
56
57
class ChoicesSerializer(BaseSerializer):
    """Serialize a ``Choices`` member by serializing its underlying ``.value``."""

    def serialize(self) -> tuple[str, set[str]]:
        raw_value = self.value.value
        delegate = serializer_factory(raw_value)
        return delegate.serialize()
61
62
class DateTimeSerializer(BaseSerializer):
    """For datetime.*, except datetime.datetime."""

    def serialize(self) -> tuple[str, set[str]]:
        # The repr of these values is spelled through the ``datetime`` module,
        # so that module must be imported by the migration.
        required = {"import datetime"}
        return repr(self.value), required
68
69
class DatetimeDatetimeSerializer(BaseSerializer):
    """For datetime.datetime.

    Aware datetimes in a non-UTC zone are converted to UTC before being
    rendered, so the migration only ever references ``datetime.UTC``-style
    tzinfo. Naive datetimes are rendered unchanged.
    """

    def serialize(self) -> tuple[str, set[str]]:
        # Convert into a local rather than overwriting self.value, so calling
        # serialize() has no side effect on the serializer's state.
        value = self.value
        if value.tzinfo is not None and value.tzinfo != datetime.UTC:
            value = value.astimezone(datetime.UTC)
        return repr(value), {"import datetime"}
78
79
class DecimalSerializer(BaseSerializer):
    """Serialize ``decimal.Decimal`` via its repr, e.g. ``Decimal('1.50')``."""

    def serialize(self) -> tuple[str, set[str]]:
        # repr() spells the bare name ``Decimal``, hence the from-import.
        source = repr(self.value)
        return source, {"from decimal import Decimal"}
83
84
class DeconstructableSerializer(BaseSerializer):
    """Serialize objects exposing ``deconstruct() -> (path, args, kwargs)``.

    The result is a constructor call such as ``models.Foo(1, bar=2)``.
    """

    @staticmethod
    def serialize_deconstructed(
        path: str, args: tuple[Any, ...], kwargs: dict[str, Any]
    ) -> tuple[str, set[str]]:
        """Render ``path(*args, **kwargs)`` as source plus required imports.

        Keyword arguments are emitted sorted, keeping the output stable.
        """
        callable_name, imports = DeconstructableSerializer._serialize_path(path)
        call_arguments: list[str] = []
        for positional in args:
            positional_source, positional_imports = serializer_factory(
                positional
            ).serialize()
            call_arguments.append(positional_source)
            imports.update(positional_imports)
        for keyword, keyword_value in sorted(kwargs.items()):
            keyword_source, keyword_imports = serializer_factory(
                keyword_value
            ).serialize()
            imports.update(keyword_imports)
            call_arguments.append(f"{keyword}={keyword_source}")
        rendered_arguments = ", ".join(call_arguments)
        return f"{callable_name}({rendered_arguments})", imports

    @staticmethod
    def _serialize_path(path: str) -> tuple[str, set[str]]:
        """Return ``(reference, imports)`` for a dotted import path.

        Names defined directly in ``plain.models`` are referenced through the
        ``models`` alias; any other path is used verbatim with its module
        imported.
        """
        module, name = path.rsplit(".", 1)
        if module != "plain.models":
            return path, {f"import {module}"}
        return f"models.{name}", {"from plain import models"}

    def serialize(self) -> tuple[str, set[str]]:
        deconstructed = self.value.deconstruct()
        return self.serialize_deconstructed(*deconstructed)
115
116
class DictionarySerializer(BaseSerializer):
    """Serialize a ``dict`` as a dict literal with entries in a stable order."""

    def serialize(self) -> tuple[str, set[str]]:
        imports: set[str] = set()
        strings = []
        for k, v in self._sorted_items():
            k_string, k_imports = serializer_factory(k).serialize()
            v_string, v_imports = serializer_factory(v).serialize()
            imports.update(k_imports)
            imports.update(v_imports)
            strings.append((k_string, v_string))
        return "{{{}}}".format(", ".join(f"{k}: {v}" for k, v in strings)), imports

    def _sorted_items(self) -> list[tuple[Any, Any]]:
        """Return the dict's items in a deterministic order.

        Natural key ordering is used whenever possible, which keeps output
        unchanged for homogeneous keys. Keys of mutually unorderable types
        (e.g. ``{1: ..., "a": ...}``) make ``sorted()`` raise TypeError; in
        that case order by each key's ``repr()`` instead of failing.
        """
        items = self.value.items()
        try:
            return sorted(items)
        except TypeError:
            return sorted(items, key=lambda item: repr(item[0]))
128
129
class EnumSerializer(BaseSerializer):
    """Serialize an enum member as ``module.Class['NAME']``.

    Flag values are written as an ``|``-joined expression of their member
    flags. A Flag value that has no member flags set is written as
    ``module.Class(<raw value>)``.
    """

    def serialize(self) -> tuple[str, set[str]]:
        enum_class = self.value.__class__
        module = enum_class.__module__
        class_path = f"{module}.{enum_class.__qualname__}"
        imports = {f"import {module}"}
        if issubclass(enum_class, enum.Flag):
            # Iterating a Flag yields its canonical member flags.
            members = list(self.value)
        else:
            members = [self.value]
        if not members:
            # e.g. Perm(0), or a zero-valued alias such as NONE = 0: nothing to
            # join, and an empty string is not a valid expression. Rebuild the
            # value from its raw value instead.
            return f"{class_path}({self.value.value!r})", imports
        return (
            " | ".join(f"{class_path}[{item.name!r}]" for item in members),
            imports,
        )
147
148
class FloatSerializer(BaseSimpleSerializer):
    """Serialize floats, spelling NaN and infinities as ``float("...")``."""

    def serialize(self) -> tuple[str, set[str]]:
        # Finite floats round-trip through repr(); "nan"/"inf"/"-inf" are not
        # valid Python expressions on their own.
        if math.isfinite(self.value):
            return super().serialize()
        return f'float("{self.value}")', set()
154
155
class FrozensetSerializer(BaseSequenceSerializer):
    """Serialize a ``frozenset`` as ``frozenset([...])`` with stable item order."""

    def __init__(self, value: Any) -> None:
        # Frozenset iteration order follows element hashes, and str/bytes
        # hashes are randomized per process (PYTHONHASHSEED). Without sorting,
        # the same value could be written differently on every run. Sorting by
        # repr() works even for mixed, mutually unorderable element types.
        super().__init__(sorted(value, key=repr))

    def _format(self) -> str:
        return "frozenset([%s])"
159
160
class FunctionTypeSerializer(BaseSerializer):
    """Serialize functions, builtins and class-bound methods by import path.

    Raises:
        ValueError: for lambdas, functions without a module, and functions
            whose qualified name is not importable (e.g. defined in
            ``<locals>``).
    """

    def serialize(self) -> tuple[str, set[str]]:
        klass = getattr(self.value, "__self__", None)
        # Test the type explicitly instead of relying on truthiness: a class
        # can be falsy (e.g. an Enum without members defines __len__).
        if isinstance(klass, type):
            # Method bound to a class (e.g. a classmethod). __qualname__ keeps
            # enclosing classes (Outer.Inner); __name__ would drop them and
            # produce a path that cannot be imported.
            module = klass.__module__
            return f"{module}.{klass.__qualname__}.{self.value.__name__}", {
                f"import {module}"
            }
        # Further error checking
        if self.value.__name__ == "<lambda>":
            raise ValueError("Cannot serialize function: lambda")
        if self.value.__module__ is None:
            raise ValueError(f"Cannot serialize function {self.value!r}: No module")

        module_name = self.value.__module__

        if "<" not in self.value.__qualname__:  # Qualname can include <locals>
            return f"{module_name}.{self.value.__qualname__}", {
                f"import {self.value.__module__}"
            }

        raise ValueError(
            f"Could not find function {self.value.__name__} in {module_name}.\n"
        )
187
188
class FunctoolsPartialSerializer(BaseSerializer):
    """Serialize ``functools.partial`` / ``partialmethod`` objects.

    Output has the shape ``functools.<cls>(func, *args, **keywords)``.
    """

    def serialize(self) -> tuple[str, set[str]]:
        wrapped = self.value
        func_source, func_imports = serializer_factory(wrapped.func).serialize()
        positional_source, positional_imports = serializer_factory(
            wrapped.args
        ).serialize()
        keyword_source, keyword_imports = serializer_factory(
            wrapped.keywords
        ).serialize()
        # functools itself plus whatever the wrapped pieces need.
        imports: set[str] = {"import functools"}.union(
            func_imports, positional_imports, keyword_imports
        )
        class_name = wrapped.__class__.__name__
        source = (
            f"functools.{class_name}({func_source}, "
            f"*{positional_source}, **{keyword_source})"
        )
        return source, imports
208
209
class IterableSerializer(BaseSerializer):
    """Serialize an arbitrary iterable as a tuple literal."""

    def serialize(self) -> tuple[str, set[str]]:
        collected_imports: set[str] = set()
        parts: list[str] = []
        for element in self.value:
            element_source, element_imports = serializer_factory(element).serialize()
            collected_imports.update(element_imports)
            parts.append(element_source)
        body = ", ".join(parts)
        # Only a one-element tuple takes a trailing comma; "(,)" for an empty
        # iterable would be invalid syntax.
        if len(parts) == 1:
            return f"({body},)", collected_imports
        return f"({body})", collected_imports
222
223
class ModelFieldSerializer(DeconstructableSerializer):
    """Serialize a model ``Field`` from its four-part ``deconstruct()`` result.

    The first element (the attribute name) is not part of the constructor
    call, so only ``path``, ``args`` and ``kwargs`` are rendered.
    """

    def serialize(self) -> tuple[str, set[str]]:
        _attr_name, field_path, field_args, field_kwargs = self.value.deconstruct()
        return self.serialize_deconstructed(field_path, field_args, field_kwargs)
228
229
class OperationSerializer(BaseSerializer):
    """Serialize a nested migration ``Operation`` through ``OperationWriter``."""

    def serialize(self) -> tuple[str, set[str]]:
        # Function-level import; presumably avoids an import cycle with the
        # writer module — confirm.
        from plain.models.migrations.writer import OperationWriter

        writer = OperationWriter(self.value, indentation=0)
        source, imports = writer.serialize()
        # Nested operation, trailing comma is handled in upper OperationWriter._write()
        return source.rstrip(","), imports
237
238
class PathLikeSerializer(BaseSerializer):
    """Serialize an ``os.PathLike`` as the repr of its filesystem path."""

    def serialize(self) -> tuple[str, set[str]]:
        filesystem_path = os.fspath(self.value)
        return repr(filesystem_path), set()
242
243
class PathSerializer(BaseSerializer):
    """Serialize ``pathlib`` paths, always as a ``Pure*`` path class."""

    def serialize(self) -> tuple[str, set[str]]:
        # Concrete paths (Path/PosixPath/WindowsPath) are written as their
        # Pure* counterparts so a migration generated on one platform can
        # still be loaded on another.
        is_concrete = isinstance(self.value, pathlib.Path)
        rendered = repr(self.value)
        if is_concrete:
            rendered = "Pure" + rendered
        return f"pathlib.{rendered}", {"import pathlib"}
250
251
class RegexSerializer(BaseSerializer):
    """Serialize a compiled regex (or ``RegexObject``) as ``re.compile(...)``."""

    def serialize(self) -> tuple[str, set[str]]:
        pattern_source, pattern_imports = serializer_factory(
            self.value.pattern
        ).serialize()
        # Drop the implicit default flags (e.g. re.U): regexes with the same
        # implicit and explicit flags don't compare equal.
        explicit_flags = self.value.flags ^ re.compile("").flags
        flags_source, flags_imports = serializer_factory(explicit_flags).serialize()
        call_arguments = [pattern_source]
        if explicit_flags:
            call_arguments.append(flags_source)
        imports: set[str] = {"import re"}.union(pattern_imports, flags_imports)
        return f"re.compile({', '.join(call_arguments)})", imports
266
267
class SequenceSerializer(BaseSequenceSerializer):
    """Serialize a ``list`` as a list literal."""

    def _format(self) -> str:
        list_template = "[%s]"
        return list_template
271
272
class SetSerializer(BaseSequenceSerializer):
    """Serialize a ``set`` as a set literal with stable item order."""

    def __init__(self, value: Any) -> None:
        # Set iteration order follows element hashes, and str/bytes hashes
        # are randomized per process (PYTHONHASHSEED). Without sorting, the
        # same set could be written differently on every run. Sorting by
        # repr() works even for mixed, mutually unorderable element types.
        super().__init__(sorted(value, key=repr))

    def _format(self) -> str:
        # Serialize as a set literal except when value is empty because {}
        # is an empty dict.
        return "{%s}" if self.value else "set(%s)"
278
279
class SettingsReferenceSerializer(BaseSerializer):
    """Serialize a ``SettingsReference`` as an attribute of ``settings``."""

    def serialize(self) -> tuple[str, set[str]]:
        setting_name = self.value.setting_name
        required = {"from plain.runtime import settings"}
        return f"settings.{setting_name}", required
285
286
class TupleSerializer(BaseSequenceSerializer):
    """Serialize a ``tuple`` as a tuple literal."""

    def _format(self) -> str:
        # Only a one-element tuple needs the trailing comma; writing "(,)"
        # for the empty tuple would be invalid syntax.
        if len(self.value) == 1:
            return "(%s,)"
        return "(%s)"
292
293
class TypeSerializer(BaseSerializer):
    """Serialize a class object as an importable reference.

    Builtins are referenced by bare name; other classes as
    ``module.QualName`` with their module imported.
    """

    def serialize(self) -> tuple[str, set[str]]:
        # Types that need a specific spelling rather than module.qualname.
        if self.value is Model:
            return "models.Model", {"from plain import models"}
        if self.value is types.NoneType:
            return "types.NoneType", {"import types"}
        if not hasattr(self.value, "__module__"):
            return "", set()
        owner_module = self.value.__module__
        if owner_module == builtins.__name__:
            return self.value.__name__, set()
        return f"{owner_module}.{self.value.__qualname__}", {f"import {owner_module}"}
310
311
class UUIDSerializer(BaseSerializer):
    """Serialize ``uuid.UUID`` as ``uuid.UUID('...')``."""

    def serialize(self) -> tuple[str, set[str]]:
        source = "uuid." + repr(self.value)
        return source, {"import uuid"}
315
316
class Serializer:
    """Type-to-serializer registry consulted by ``serializer_factory()``."""

    # Lookup is first-match by isinstance() in insertion order, so some of
    # these are order-dependent: more specific types must come before the
    # general ones they would otherwise match (Choices before enum.Enum,
    # datetime.datetime before datetime.date, str/bytes before
    # collections.abc.Iterable). Keys are single types or tuples of types.
    _registry: dict[Any, type[BaseSerializer]] = {
        frozenset: FrozensetSerializer,
        list: SequenceSerializer,
        set: SetSerializer,
        tuple: TupleSerializer,
        dict: DictionarySerializer,
        Choices: ChoicesSerializer,
        enum.Enum: EnumSerializer,
        datetime.datetime: DatetimeDatetimeSerializer,
        (datetime.date, datetime.timedelta, datetime.time): DateTimeSerializer,
        SettingsReference: SettingsReferenceSerializer,
        float: FloatSerializer,
        (bool, int, types.NoneType, bytes, str, range): BaseSimpleSerializer,
        decimal.Decimal: DecimalSerializer,
        (functools.partial, functools.partialmethod): FunctoolsPartialSerializer,
        (
            types.FunctionType,
            types.BuiltinFunctionType,
            types.MethodType,
        ): FunctionTypeSerializer,
        collections.abc.Iterable: IterableSerializer,
        (COMPILED_REGEX_TYPE, RegexObject): RegexSerializer,
        uuid.UUID: UUIDSerializer,
        pathlib.PurePath: PathSerializer,
        os.PathLike: PathLikeSerializer,
    }

    @classmethod
    def register(cls, type_: type[Any], serializer: type[BaseSerializer]) -> None:
        """Register ``serializer`` for values that are instances of ``type_``.

        Re-registering an existing key replaces its serializer in place. A new
        key is appended, so it is consulted after all built-in entries; an
        iterable ``type_``, for example, is still matched by
        ``collections.abc.Iterable`` first.

        Raises:
            ValueError: if ``serializer`` is not a ``BaseSerializer`` subclass.
        """
        # Check isinstance(type) first: issubclass() on a non-class raises
        # TypeError, which would bypass the intended ValueError.
        if not (
            isinstance(serializer, type) and issubclass(serializer, BaseSerializer)
        ):
            name = getattr(serializer, "__name__", repr(serializer))
            raise ValueError(f"'{name}' must inherit from 'BaseSerializer'.")
        cls._registry[type_] = serializer
353
354
def serializer_factory(value: Any) -> BaseSerializer:
    """Pick and instantiate the serializer for ``value``.

    Lazy values are resolved first. Model fields, migration operations,
    classes, and anything with a ``deconstruct()`` method are handled before
    ``Serializer._registry`` is consulted. Within the registry, the first
    entry whose type matches via ``isinstance()`` wins.

    Raises:
        ValueError: if no serializer accepts ``value``.
    """
    if isinstance(value, Promise):
        value = str(value)
    elif isinstance(value, LazyObject):
        # The unwrapped value is returned as the first item of the arguments
        # tuple.
        value = value.__reduce__()[1][0]

    # Checked in this order, ahead of the deconstruct() hook and the registry.
    precedence = (
        (Field, ModelFieldSerializer),
        (Operation, OperationSerializer),
        (type, TypeSerializer),
    )
    for kind, special_serializer in precedence:
        if isinstance(value, kind):
            return special_serializer(value)

    # Anything that knows how to deconstruct itself.
    if hasattr(value, "deconstruct"):
        return DeconstructableSerializer(value)

    chosen = next(
        (
            registered_serializer
            for registered_type, registered_serializer in Serializer._registry.items()
            if isinstance(value, registered_type)
        ),
        None,
    )
    if chosen is not None:
        return chosen(value)

    raise ValueError(
        f"Cannot serialize: {value!r}\nThere are some values Plain cannot serialize into "
        "migration files."
    )