1from __future__ import annotations
2
3import builtins
4import collections.abc
5import datetime
6import decimal
7import enum
8import functools
9import math
10import os
11import pathlib
12import re
13import types
14import uuid
15from typing import Any
16
17from plain import models
18from plain.models.migrations.operations.base import Operation
19from plain.models.migrations.utils import COMPILED_REGEX_TYPE, RegexObject
20from plain.runtime import SettingsReference
21from plain.utils.functional import LazyObject, Promise
22
23
class BaseSerializer:
    """Common base for objects that render a value as migration source code.

    The wrapped value is kept on ``self.value``. Concrete subclasses turn it
    into a ``(source_string, import_lines)`` pair in ``serialize()``.
    """

    def __init__(self, value: Any) -> None:
        self.value = value

    def serialize(self) -> tuple[str, set[str]]:
        """Return the source string and the import lines it depends on.

        Raises:
            NotImplementedError: Always; subclasses must override this method.
        """
        message = "Subclasses of BaseSerializer must implement the serialize() method."
        raise NotImplementedError(message)
32
33
class BaseSequenceSerializer(BaseSerializer):
    """Shared logic for serializers whose value is a collection of items.

    Subclasses provide a ``%s`` template through ``_format()``; the serialized
    items, joined with ``", "``, are substituted into that template.
    """

    def _format(self) -> str:
        """Return the ``%``-style template that wraps the joined items."""
        raise NotImplementedError(
            "Subclasses of BaseSequenceSerializer must implement the _format() method."
        )

    def serialize(self) -> tuple[str, set[str]]:
        """Serialize each item and wrap the joined result in the template."""
        rendered_items: list[str] = []
        collected_imports: set[str] = set()
        for element in self.value:
            element_source, element_imports = serializer_factory(element).serialize()
            rendered_items.append(element_source)
            collected_imports.update(element_imports)
        # The template is requested only after every item has been serialized.
        template = self._format()
        joined = ", ".join(rendered_items)
        return template % (joined), collected_imports
49
50
class BaseSimpleSerializer(BaseSerializer):
    """Serializer for values whose ``repr()`` is used directly as source."""

    def serialize(self) -> tuple[str, set[str]]:
        """Return ``repr(value)`` together with an empty import set."""
        no_imports: set[str] = set()
        return repr(self.value), no_imports
54
55
class ChoicesSerializer(BaseSerializer):
    """Serializer that writes a choices member as its underlying ``.value``."""

    def serialize(self) -> tuple[str, set[str]]:
        """Delegate to whichever serializer handles the member's raw value."""
        underlying = self.value.value
        delegate = serializer_factory(underlying)
        return delegate.serialize()
59
60
class DateTimeSerializer(BaseSerializer):
    """Serializer for ``datetime`` module values other than ``datetime.datetime``."""

    def serialize(self) -> tuple[str, set[str]]:
        """Return the value's ``repr()`` plus the ``datetime`` import it needs."""
        source = repr(self.value)
        return source, {"import datetime"}
66
67
class DatetimeDatetimeSerializer(BaseSerializer):
    """Serializer for ``datetime.datetime`` values."""

    def serialize(self) -> tuple[str, set[str]]:
        """Return the datetime's ``repr()``, converting aware values to UTC first.

        Naive datetimes and datetimes already in UTC are written unchanged.
        """
        tz = self.value.tzinfo
        needs_conversion = tz is not None and tz != datetime.UTC
        if needs_conversion:
            # The converted datetime replaces the stored value.
            self.value = self.value.astimezone(datetime.UTC)
        return repr(self.value), {"import datetime"}
76
77
class DecimalSerializer(BaseSerializer):
    """Serializer for ``decimal.Decimal`` values."""

    def serialize(self) -> tuple[str, set[str]]:
        """Return the ``repr()`` of the value plus the ``Decimal`` import line."""
        required_imports = {"from decimal import Decimal"}
        return repr(self.value), required_imports
81
82
class DeconstructableSerializer(BaseSerializer):
    """Serializer for objects whose ``deconstruct()`` yields ``(path, args, kwargs)``."""

    @staticmethod
    def serialize_deconstructed(
        path: str, args: tuple[Any, ...], kwargs: dict[str, Any]
    ) -> tuple[str, set[str]]:
        """Render a call expression ``name(arg, ..., kw=value, ...)``.

        Positional arguments keep their order; keyword arguments are written
        in sorted order. Returns the call source and every import line needed
        by the callable and its arguments.
        """
        name, imports = DeconstructableSerializer._serialize_path(path)
        rendered: list[str] = []
        for positional in args:
            source, needed = serializer_factory(positional).serialize()
            rendered.append(source)
            imports.update(needed)
        for keyword, argument in sorted(kwargs.items()):
            source, needed = serializer_factory(argument).serialize()
            imports.update(needed)
            rendered.append(f"{keyword}={source}")
        return f"{name}({', '.join(rendered)})", imports

    @staticmethod
    def _serialize_path(path: str) -> tuple[str, set[str]]:
        """Split a dotted path into the name to write and the import it needs.

        Paths directly under ``plain.models`` are written as ``models.<name>``;
        any other path is written in full with an import of its module.
        """
        module, name = path.rsplit(".", 1)
        if module != "plain.models":
            return path, {f"import {module}"}
        return f"models.{name}", {"from plain import models"}

    def serialize(self) -> tuple[str, set[str]]:
        """Serialize the wrapped object from its ``deconstruct()`` result."""
        deconstructed = self.value.deconstruct()
        return self.serialize_deconstructed(*deconstructed)
113
114
class DictionarySerializer(BaseSerializer):
    """Serializer that renders a mapping as a dict literal with sorted items."""

    def serialize(self) -> tuple[str, set[str]]:
        """Serialize each key and value, emitting ``{key: value, ...}``."""
        imports: set[str] = set()
        entries: list[str] = []
        for key, item in sorted(self.value.items()):
            key_source, key_imports = serializer_factory(key).serialize()
            item_source, item_imports = serializer_factory(item).serialize()
            imports.update(key_imports)
            imports.update(item_imports)
            entries.append(f"{key_source}: {item_source}")
        return "{" + ", ".join(entries) + "}", imports
126
127
class EnumSerializer(BaseSerializer):
    """Serializer for enum members, written as by-name lookups on their class.

    A plain member becomes ``module.Class['NAME']``. A ``Flag`` value is
    written as its individual members joined with ``|``.
    """

    def serialize(self) -> tuple[str, set[str]]:
        """Return the lookup expression and the import of the enum's module."""
        enum_class = self.value.__class__
        module = enum_class.__module__
        class_path = f"{module}.{enum_class.__qualname__}"
        imports = {f"import {module}"}
        if issubclass(enum_class, enum.Flag):
            members = list(self.value)
            if not members:
                # A flag with no bits set iterates as empty, which would
                # otherwise produce an empty (invalid) expression. Look the
                # value up by number instead, e.g. ``module.Class(0)``.
                return f"{class_path}({self.value.value!r})", imports
        else:
            members = (self.value,)
        lookups = [f"{class_path}[{item.name!r}]" for item in members]
        return " | ".join(lookups), imports
145
146
class FloatSerializer(BaseSimpleSerializer):
    """Serializer for floats, with special handling for NaN and infinities."""

    def serialize(self) -> tuple[str, set[str]]:
        """Return ``repr()`` for finite floats, else a ``float("...")`` call."""
        number = self.value
        if math.isfinite(number):
            return super().serialize()
        # repr() of nan/inf is not a valid literal, so rebuild it from a string.
        return f'float("{number}")', set()
152
153
class FrozensetSerializer(BaseSequenceSerializer):
    """Serializer that renders its items as a ``frozenset([...])`` call."""

    def __init__(self, value: Any) -> None:
        # Set iteration order is not stable between runs (string hashing is
        # randomized per process), so order the items by repr to keep the
        # generated source deterministic.
        super().__init__(sorted(value, key=repr))

    def _format(self) -> str:
        """Return the template wrapping the joined items."""
        return "frozenset([%s])"
157
158
class FunctionTypeSerializer(BaseSerializer):
    """Serializer for functions and methods, written as dotted import paths."""

    def serialize(self) -> tuple[str, set[str]]:
        """Return the dotted path of the callable and the import it needs.

        Raises:
            ValueError: If the callable is a lambda, has no module, or its
                qualified name contains ``<`` (for example ``<locals>``).
        """
        func = self.value
        owner = getattr(func, "__self__", None)
        if owner and isinstance(owner, type):
            # Callable bound to a class. Use the qualified name so that
            # classes nested inside other classes resolve to the right path.
            module = owner.__module__
            return f"{module}.{owner.__qualname__}.{func.__name__}", {
                f"import {module}"
            }
        # Further error checking
        if func.__name__ == "<lambda>":
            raise ValueError("Cannot serialize function: lambda")
        if func.__module__ is None:
            raise ValueError(f"Cannot serialize function {func!r}: No module")

        module_name = func.__module__

        if "<" not in func.__qualname__:  # Qualname can include <locals>
            return f"{module_name}.{func.__qualname__}", {
                f"import {func.__module__}"
            }

        raise ValueError(
            f"Could not find function {func.__name__} in {module_name}.\n"
        )
185
186
class FunctoolsPartialSerializer(BaseSerializer):
    """Serializer for ``functools.partial`` and ``functools.partialmethod``."""

    def serialize(self) -> tuple[str, set[str]]:
        """Rebuild the partial as ``functools.<cls>(func, *args, **keywords)``."""
        wrapped = self.value
        # Serialize the three components of the partial in turn.
        func_source, func_imports = serializer_factory(wrapped.func).serialize()
        args_source, args_imports = serializer_factory(wrapped.args).serialize()
        kwargs_source, kwargs_imports = serializer_factory(
            wrapped.keywords
        ).serialize()
        # functools itself plus whatever the components require.
        imports: set[str] = {"import functools"}
        imports.update(func_imports, args_imports, kwargs_imports)
        class_name = wrapped.__class__.__name__
        call = f"functools.{class_name}({func_source}, *{args_source}, **{kwargs_source})"
        return call, imports
206
207
class IterableSerializer(BaseSerializer):
    """Serializer that renders any iterable as a tuple literal."""

    def serialize(self) -> tuple[str, set[str]]:
        """Serialize each item and wrap the result in parentheses."""
        rendered: list[str] = []
        imports: set[str] = set()
        for element in self.value:
            element_source, element_imports = serializer_factory(element).serialize()
            imports.update(element_imports)
            rendered.append(element_source)
        # Only a one-item tuple needs the trailing comma; an empty iterable
        # must become "()" because "(,)" is not valid Python.
        template = "(%s,)" if len(rendered) == 1 else "(%s)"
        return template % (", ".join(rendered)), imports
220
221
class ModelFieldSerializer(DeconstructableSerializer):
    """Serializer for model fields, whose ``deconstruct()`` returns four items."""

    def serialize(self) -> tuple[str, set[str]]:
        """Serialize the field from its path, args and kwargs.

        The first item of the ``deconstruct()`` result is not used.
        """
        _unused_name, path, args, kwargs = self.value.deconstruct()
        return self.serialize_deconstructed(path, args, kwargs)
226
227
class OperationSerializer(BaseSerializer):
    """Serializer for migration operations nested inside another value."""

    def serialize(self) -> tuple[str, set[str]]:
        """Render the operation with ``OperationWriter`` at indentation 0."""
        from plain.models.migrations.writer import OperationWriter

        writer = OperationWriter(self.value, indentation=0)
        rendered, imports = writer.serialize()
        # For a nested operation the trailing comma is handled by the
        # enclosing OperationWriter._write(), so strip it here.
        return rendered.rstrip(","), imports
235
236
class PathLikeSerializer(BaseSerializer):
    """Serializer for ``os.PathLike`` objects, written as their fspath value."""

    def serialize(self) -> tuple[str, set[str]]:
        """Return the ``repr()`` of ``os.fspath(value)`` with no imports."""
        filesystem_path = os.fspath(self.value)
        return repr(filesystem_path), set()
240
241
class PathSerializer(BaseSerializer):
    """Serializer for ``pathlib`` path objects."""

    def serialize(self) -> tuple[str, set[str]]:
        """Return ``pathlib.<repr>``, prefixing concrete path classes with ``Pure``.

        Concrete paths are written as pure paths so that a migration generated
        on one platform can be used on a different platform.
        """
        if isinstance(self.value, pathlib.Path):
            prefix = "Pure"
        else:
            prefix = ""
        return f"pathlib.{prefix}{self.value!r}", {"import pathlib"}
248
249
class RegexSerializer(BaseSerializer):
    """Serializer that renders a regex object as a ``re.compile(...)`` call."""

    def serialize(self) -> tuple[str, set[str]]:
        """Return the ``re.compile`` call, including flags only when non-default."""
        pattern_source, pattern_imports = serializer_factory(
            self.value.pattern
        ).serialize()
        # Strip the flags that every compiled pattern gets implicitly
        # (e.g. re.U): regexes with the same implicit and explicit flags
        # aren't equal.
        implicit_flags = re.compile("").flags
        explicit_flags = self.value.flags ^ implicit_flags
        flags_source, flags_imports = serializer_factory(explicit_flags).serialize()
        imports: set[str] = {"import re"}
        imports.update(pattern_imports, flags_imports)
        call_args = [pattern_source]
        if explicit_flags:
            call_args.append(flags_source)
        return f"re.compile({', '.join(call_args)})", imports
264
265
class SequenceSerializer(BaseSequenceSerializer):
    """Serializer that renders its items as a list literal."""

    def _format(self) -> str:
        """Return the list-literal template for the joined items."""
        list_template = "[%s]"
        return list_template
269
270
class SetSerializer(BaseSequenceSerializer):
    """Serializer that renders its items as a set literal (or ``set()``)."""

    def __init__(self, value: Any) -> None:
        # Set iteration order is not stable between runs (string hashing is
        # randomized per process), so order the items by repr to keep the
        # generated source deterministic.
        super().__init__(sorted(value, key=repr))

    def _format(self) -> str:
        """Return the template wrapping the joined items."""
        # Serialize as a set literal except when value is empty because {}
        # is an empty dict.
        return "{%s}" if self.value else "set(%s)"
276
277
class SettingsReferenceSerializer(BaseSerializer):
    """Serializer that writes a settings reference as ``settings.<NAME>``."""

    def serialize(self) -> tuple[str, set[str]]:
        """Return the attribute access on ``settings`` plus its import line."""
        setting_name = self.value.setting_name
        required_imports = {"from plain.runtime import settings"}
        return f"settings.{setting_name}", required_imports
283
284
class TupleSerializer(BaseSequenceSerializer):
    """Serializer that renders its items as a tuple literal."""

    def _format(self) -> str:
        """Return the tuple template, with a trailing comma for one item only."""
        if len(self.value) == 1:
            return "(%s,)"
        # Covers the empty tuple too: it must be "()" since "(,)" is not
        # valid Python syntax.
        return "(%s)"
290
291
class TypeSerializer(BaseSerializer):
    """Serializer for classes, written as a reference to the class itself."""

    def serialize(self) -> tuple[str, set[str]]:
        """Return the name to write for the class and the import it needs.

        ``models.Model`` and ``types.NoneType`` have fixed spellings. Builtins
        are written by bare name. Any other class is written as
        ``module.qualname`` with an import of its module. A value without a
        ``__module__`` attribute yields an empty string.
        """
        target = self.value
        if target is models.Model:
            return "models.Model", {"from plain import models"}
        if target is types.NoneType:
            return "types.NoneType", {"import types"}
        if not hasattr(target, "__module__"):
            return "", set()
        module = target.__module__
        if module == builtins.__name__:
            return target.__name__, set()
        return f"{module}.{target.__qualname__}", {f"import {module}"}
308
309
class UUIDSerializer(BaseSerializer):
    """Serializer for ``uuid.UUID`` values."""

    def serialize(self) -> tuple[str, set[str]]:
        """Return the value's ``repr()`` qualified with the ``uuid`` module."""
        source = f"uuid.{self.value!r}"
        return source, {"import uuid"}
313
314
class Serializer:
    """Registry mapping value types to the serializer class that handles them."""

    # serializer_factory() walks this mapping in insertion order and uses the
    # first entry whose key matches isinstance(), so more specific types must
    # come before the broader types that would also match them.
    _registry = {
        # Some of these are order-dependent.
        frozenset: FrozensetSerializer,
        list: SequenceSerializer,
        set: SetSerializer,
        tuple: TupleSerializer,
        dict: DictionarySerializer,
        models.Choices: ChoicesSerializer,
        enum.Enum: EnumSerializer,
        # datetime.datetime is a subclass of datetime.date, so it goes first.
        datetime.datetime: DatetimeDatetimeSerializer,
        (datetime.date, datetime.timedelta, datetime.time): DateTimeSerializer,
        SettingsReference: SettingsReferenceSerializer,
        float: FloatSerializer,
        (bool, int, types.NoneType, bytes, str, range): BaseSimpleSerializer,
        decimal.Decimal: DecimalSerializer,
        (functools.partial, functools.partialmethod): FunctoolsPartialSerializer,
        (
            types.FunctionType,
            types.BuiltinFunctionType,
            types.MethodType,
        ): FunctionTypeSerializer,
        # Generic fallback for iterables not matched by an entry above.
        collections.abc.Iterable: IterableSerializer,
        (COMPILED_REGEX_TYPE, RegexObject): RegexSerializer,
        uuid.UUID: UUIDSerializer,
        # pathlib paths are also os.PathLike, so they are listed first.
        pathlib.PurePath: PathSerializer,
        os.PathLike: PathLikeSerializer,
    }

    @classmethod
    def register(cls, type_: type[Any], serializer: type[BaseSerializer]) -> None:
        """Associate ``type_`` with ``serializer`` in the registry.

        An existing entry for ``type_`` is replaced.

        Raises:
            ValueError: If ``serializer`` is not a subclass of BaseSerializer.
        """
        if issubclass(serializer, BaseSerializer):
            cls._registry[type_] = serializer
            return
        raise ValueError(
            f"'{serializer.__name__}' must inherit from 'BaseSerializer'."
        )
351
352
def serializer_factory(value: Any) -> BaseSerializer:
    """Pick the serializer instance responsible for ``value``.

    Lazy wrappers are unwrapped first. Model fields, operations and classes
    are then checked, followed by anything with a ``deconstruct`` attribute,
    and finally the ``Serializer`` registry in its insertion order.

    Raises:
        ValueError: If no serializer matches the value.
    """
    if isinstance(value, Promise):
        value = str(value)
    elif isinstance(value, LazyObject):
        # __reduce__() returns the unwrapped value as the first item of its
        # arguments tuple.
        value = value.__reduce__()[1][0]

    # These checks take priority over the registry, in this order.
    priority_checks = (
        (models.Field, ModelFieldSerializer),
        (Operation, OperationSerializer),
        (type, TypeSerializer),
    )
    for kind, serializer_cls in priority_checks:
        if isinstance(value, kind):
            return serializer_cls(value)

    # Anything that knows how to deconstruct itself.
    if hasattr(value, "deconstruct"):
        return DeconstructableSerializer(value)

    for registered_type, serializer_cls in Serializer._registry.items():
        if isinstance(value, registered_type):
            return serializer_cls(value)

    raise ValueError(
        f"Cannot serialize: {value!r}\nThere are some values Plain cannot serialize into "
        "migration files."
    )