1import builtins
2import collections.abc
3import datetime
4import decimal
5import enum
6import functools
7import math
8import os
9import pathlib
10import re
11import types
12import uuid
13
14from plain import models
15from plain.models.migrations.operations.base import Operation
16from plain.models.migrations.utils import COMPILED_REGEX_TYPE, RegexObject
17from plain.runtime.user_settings import SettingsReference
18from plain.utils.functional import LazyObject, Promise
19
20
class BaseSerializer:
    """Base class for objects that render a value as migration source code.

    Subclasses override ``serialize()`` and return a 2-tuple of
    ``(source_string, imports)``, where ``imports`` holds the import lines the
    generated source depends on.
    """

    def __init__(self, value):
        # The object to render; read by serialize() in subclasses.
        self.value = value

    def serialize(self):
        """Return ``(source_string, imports)``; must be overridden."""
        message = "Subclasses of BaseSerializer must implement the serialize() method."
        raise NotImplementedError(message)
29
30
class BaseSequenceSerializer(BaseSerializer):
    """Serialize an iterable by serializing each element and wrapping the result.

    Subclasses provide ``_format()``: a %-style template with one ``%s`` slot
    that receives the comma-separated element sources.
    """

    def _format(self):
        raise NotImplementedError(
            "Subclasses of BaseSequenceSerializer must implement the _format() method."
        )

    def serialize(self):
        serialized = [serializer_factory(element).serialize() for element in self.value]
        imports = set()
        for _, element_imports in serialized:
            imports.update(element_imports)
        body = ", ".join(source for source, _ in serialized)
        return self._format() % body, imports
46
47
class BaseSimpleSerializer(BaseSerializer):
    """Serialize values whose ``repr()`` is already usable as source code."""

    def serialize(self):
        # No imports are required for these values.
        source = repr(self.value)
        return source, set()
51
52
class ChoicesSerializer(BaseSerializer):
    """Serialize a ``models.Choices`` member by delegating on its ``.value``."""

    def serialize(self):
        underlying = self.value.value
        return serializer_factory(underlying).serialize()
56
57
class DateTimeSerializer(BaseSerializer):
    """For datetime.*, except datetime.datetime."""

    def serialize(self):
        # repr() of these objects is qualified with "datetime.", so the
        # generated source needs the module imported.
        imports = {"import datetime"}
        return repr(self.value), imports
63
64
class DatetimeDatetimeSerializer(BaseSerializer):
    """For datetime.datetime.

    Aware datetimes in a zone other than ``datetime.timezone.utc`` are
    converted to UTC first; naive datetimes are emitted unchanged.
    """

    def serialize(self):
        tz = self.value.tzinfo
        if tz is not None and tz != datetime.timezone.utc:
            # Note: this rebinds self.value to the UTC-converted datetime.
            self.value = self.value.astimezone(datetime.timezone.utc)
        return repr(self.value), {"import datetime"}
73
74
class DecimalSerializer(BaseSerializer):
    """Serialize ``decimal.Decimal`` using its repr, e.g. ``Decimal('1.5')``."""

    def serialize(self):
        # The repr uses the bare name Decimal, so import it directly.
        imports = {"from decimal import Decimal"}
        return repr(self.value), imports
78
79
class DeconstructableSerializer(BaseSerializer):
    """Serialize objects whose ``deconstruct()`` returns ``(path, args, kwargs)``."""

    @staticmethod
    def serialize_deconstructed(path, args, kwargs):
        """Render ``path(*args, **kwargs)`` as source.

        Keyword arguments are emitted sorted by name. Returns
        ``(source_string, imports)``.
        """
        name, imports = DeconstructableSerializer._serialize_path(path)
        rendered = []
        for positional in args:
            source, extra = serializer_factory(positional).serialize()
            rendered.append(source)
            imports.update(extra)
        for keyword in sorted(kwargs):
            source, extra = serializer_factory(kwargs[keyword]).serialize()
            imports.update(extra)
            rendered.append(f"{keyword}={source}")
        return f"{name}({', '.join(rendered)})", imports

    @staticmethod
    def _serialize_path(path):
        """Return ``(name_in_source, imports)`` for a dotted import path.

        Paths whose module is exactly ``plain.models`` are shortened to
        ``models.<name>``; anything else is referenced by its full path.
        """
        module, name = path.rsplit(".", 1)
        if module != "plain.models":
            return path, {"import %s" % module}
        return "models.%s" % name, {"from plain import models"}

    def serialize(self):
        return self.serialize_deconstructed(*self.value.deconstruct())
108
109
class DictionarySerializer(BaseSerializer):
    """Serialize a dict as a literal whose entries follow sorted item order."""

    def serialize(self):
        imports = set()
        entries = []
        for key, val in sorted(self.value.items()):
            key_src, key_imports = serializer_factory(key).serialize()
            val_src, val_imports = serializer_factory(val).serialize()
            imports.update(key_imports, val_imports)
            entries.append(f"{key_src}: {val_src}")
        return "{%s}" % ", ".join(entries), imports
121
122
class EnumSerializer(BaseSerializer):
    """Serialize an ``enum.Enum`` member as a lookup on its class.

    Plain members render as ``module.Class['NAME']``. ``enum.Flag`` values
    render as ``module.Class['A'] | module.Class['B']`` for each member
    yielded by iterating the value.

    Sometimes those members cannot rebuild the value. Examples are the empty
    flag ``Class(0)``, which iterates to nothing and would otherwise produce
    an empty (invalid) source string, and an ``IntFlag`` carrying bits with
    no named member. In those cases the value is rendered as the constructor
    call ``module.Class(<int>)`` instead.
    """

    def serialize(self):
        enum_class = self.value.__class__
        module = enum_class.__module__
        class_path = f"{module}.{enum_class.__qualname__}"
        imports = {"import %s" % module}
        if issubclass(enum_class, enum.Flag):
            members = list(self.value)
            covered = 0
            for member in members:
                covered |= member.value
            if not members or covered != self.value.value:
                # Joining the members would lose information (or yield "").
                return f"{class_path}({self.value.value!r})", imports
        else:
            members = (self.value,)
        return (
            " | ".join(f"{class_path}[{item.name!r}]" for item in members),
            imports,
        )
140
141
class FloatSerializer(BaseSimpleSerializer):
    """Serialize floats, spelling NaN and infinities as ``float("...")``."""

    def serialize(self):
        value = self.value
        if not (math.isnan(value) or math.isinf(value)):
            return super().serialize()
        # repr() of nan/inf is not a valid Python expression on its own.
        return f'float("{value}")', set()
147
148
class FrozensetSerializer(BaseSequenceSerializer):
    """Serialize a frozenset as ``frozenset([...])``.

    Elements are sorted by ``repr`` before serialization. A frozenset's
    iteration order depends on element hashes, and for str/bytes those vary
    between interpreter runs (hash randomization). Without sorting, the same
    value could produce different migration text on each run.
    """

    def __init__(self, value):
        super().__init__(sorted(value, key=repr))

    def _format(self):
        return "frozenset([%s])"
152
153
class FunctionTypeSerializer(BaseSerializer):
    """Serialize a function or method as an importable dotted reference.

    Raises:
        ValueError: for lambdas, functions without a module, and functions
            whose qualname contains "<" (e.g. ones defined in a local scope).
    """

    def serialize(self):
        func = self.value
        owner = getattr(func, "__self__", None)
        if owner and isinstance(owner, type):
            # Bound to a class object: reference it as module.Class.method.
            module = owner.__module__
            return f"{module}.{owner.__name__}.{func.__name__}", {"import %s" % module}
        # Further error checking
        if func.__name__ == "<lambda>":
            raise ValueError("Cannot serialize function: lambda")
        if func.__module__ is None:
            raise ValueError("Cannot serialize function %r: No module" % func)
        module_name = func.__module__
        # A qualname containing "<" (such as "<locals>") cannot be imported.
        if "<" in func.__qualname__:
            raise ValueError(
                f"Could not find function {func.__name__} in {module_name}.\n"
            )
        return f"{module_name}.{func.__qualname__}", {"import %s" % func.__module__}
180
181
class FunctoolsPartialSerializer(BaseSerializer):
    """Serialize ``functools.partial`` / ``functools.partialmethod`` objects.

    Output has the form ``functools.<class>(func, *args, **keywords)``.
    """

    def serialize(self):
        partial_obj = self.value
        imports = {"import functools"}
        pieces = []
        for component in (partial_obj.func, partial_obj.args, partial_obj.keywords):
            source, extra = serializer_factory(component).serialize()
            pieces.append(source)
            imports.update(extra)
        func_src, args_src, keywords_src = pieces
        cls_name = partial_obj.__class__.__name__
        return (
            f"functools.{cls_name}({func_src}, *{args_src}, **{keywords_src})",
            imports,
        )
201
202
class IterableSerializer(BaseSerializer):
    """Serialize an arbitrary iterable as a tuple literal."""

    def serialize(self):
        imports = set()
        sources = []
        for element in self.value:
            src, element_imports = serializer_factory(element).serialize()
            imports.update(element_imports)
            sources.append(src)
        # Only a 1-tuple needs the trailing comma; "(,)" is invalid syntax,
        # so the empty case stays "()".
        if len(sources) == 1:
            return "(%s,)" % sources[0], imports
        return "(%s)" % ", ".join(sources), imports
215
216
class ModelFieldSerializer(DeconstructableSerializer):
    """Serialize a model field from its four-item ``deconstruct()`` result.

    The leading attribute-name item is ignored; the remaining
    ``(path, args, kwargs)`` are rendered as a constructor call.
    """

    def serialize(self):
        _attr_name, field_path, field_args, field_kwargs = self.value.deconstruct()
        return self.serialize_deconstructed(field_path, field_args, field_kwargs)
221
222
class ModelManagerSerializer(DeconstructableSerializer):
    """Serialize a model manager from its five-item ``deconstruct()`` result.

    When the ``as_manager`` flag is set, the output is
    ``<queryset path>.as_manager()``. Otherwise it is a constructor call on
    the manager path.
    """

    def serialize(self):
        as_manager, manager_path, qs_path, args, kwargs = self.value.deconstruct()
        if not as_manager:
            return self.serialize_deconstructed(manager_path, args, kwargs)
        qs_name, imports = self._serialize_path(qs_path)
        return f"{qs_name}.as_manager()", imports
231
232
class OperationSerializer(BaseSerializer):
    """Serialize a nested migration ``Operation`` using ``OperationWriter``."""

    def serialize(self):
        # Local import; presumably avoids a circular import with the writer
        # module (NOTE(review): confirm).
        from plain.models.migrations.writer import OperationWriter

        writer = OperationWriter(self.value, indentation=0)
        source, imports = writer.serialize()
        # Nested operation, trailing comma is handled in upper OperationWriter._write()
        return source.rstrip(","), imports
240
241
class PathLikeSerializer(BaseSerializer):
    """Serialize an ``os.PathLike`` as the repr of its filesystem path."""

    def serialize(self):
        # No imports are needed. Return an empty *set*, not ``{}`` (which is
        # a dict), so this matches the (source, set_of_imports) shape the
        # other serializers return.
        return repr(os.fspath(self.value)), set()
245
246
class PathSerializer(BaseSerializer):
    """Serialize ``pathlib`` paths, always emitting a pure path class."""

    def serialize(self):
        # Convert concrete paths to pure paths to avoid issues with migrations
        # generated on one platform being used on a different platform.
        if isinstance(self.value, pathlib.Path):
            source = f"pathlib.Pure{self.value!r}"
        else:
            source = f"pathlib.{self.value!r}"
        return source, {"import pathlib"}
253
254
class RegexSerializer(BaseSerializer):
    """Serialize a compiled regex as ``re.compile(pattern[, flags])``.

    Only flags beyond the defaults that ``re.compile`` applies to a pattern
    of the same type (str or bytes) are emitted.
    """

    def serialize(self):
        pattern = self.value.pattern
        regex_pattern, pattern_imports = serializer_factory(pattern).serialize()
        # Strip the implicit default flags (e.g. re.UNICODE for str patterns)
        # because regexes with the same implicit and explicit flags aren't
        # equal. The defaults come from an empty pattern of the *same* type,
        # because bytes patterns get no implicit re.UNICODE. Clearing bits
        # (rather than XOR) guarantees a default flag is never *added*.
        # Adding one breaks bytes patterns (UNICODE is rejected) and
        # re.ASCII str patterns (ASCII and UNICODE are incompatible).
        default_flags = re.compile(pattern[:0]).flags
        flags = self.value.flags & ~default_flags
        regex_flags, flag_imports = serializer_factory(flags).serialize()
        imports = {"import re", *pattern_imports, *flag_imports}
        args = [regex_pattern]
        if flags:
            args.append(regex_flags)
        return "re.compile(%s)" % ", ".join(args), imports
269
270
class SequenceSerializer(BaseSequenceSerializer):
    """Serialize a list as a list literal."""

    def _format(self):
        template = "[%s]"
        return template
274
275
class SetSerializer(BaseSequenceSerializer):
    """Serialize a set as a set literal (``set()`` when empty).

    Elements are sorted by ``repr`` before serialization. A set's iteration
    order depends on element hashes, and for str/bytes those vary between
    interpreter runs (hash randomization). Without sorting, the same value
    could produce different migration text on each run.
    """

    def __init__(self, value):
        super().__init__(sorted(value, key=repr))

    def _format(self):
        # Serialize as a set literal except when value is empty because {}
        # is an empty dict.
        return "{%s}" if self.value else "set(%s)"
281
282
class SettingsReferenceSerializer(BaseSerializer):
    """Serialize a ``SettingsReference`` as an attribute of ``settings``."""

    def serialize(self):
        imports = {"from plain.runtime import settings"}
        source = "settings.%s" % self.value.setting_name
        return source, imports
288
289
class TupleSerializer(BaseSequenceSerializer):
    """Serialize a tuple as a tuple literal."""

    def _format(self):
        # A single-element tuple needs a trailing comma; the empty tuple must
        # stay "()" because "(,)" is invalid Python syntax.
        if len(self.value) == 1:
            return "(%s,)"
        return "(%s)"
295
296
class TypeSerializer(BaseSerializer):
    """Serialize a class object as a reference to it.

    ``models.Model`` and ``types.NoneType`` have fixed spellings. Builtins
    are referenced by bare name; anything else is ``module.QualName``.
    """

    def serialize(self):
        special_cases = (
            (models.Model, "models.Model", ("from plain import models",)),
            (types.NoneType, "types.NoneType", ("import types",)),
        )
        for candidate, source, needed in special_cases:
            if self.value is candidate:
                return source, set(needed)
        if not hasattr(self.value, "__module__"):
            # NOTE(review): returns None here, which callers that unpack a
            # (source, imports) pair will reject.
            return None
        module = self.value.__module__
        if module == builtins.__name__:
            return self.value.__name__, set()
        return f"{module}.{self.value.__qualname__}", {"import %s" % module}
312
313
class UUIDSerializer(BaseSerializer):
    """Serialize a ``uuid.UUID`` as ``uuid.UUID('...')``."""

    def serialize(self):
        return f"uuid.{self.value!r}", {"import uuid"}
317
318
class Serializer:
    """Registry mapping value types to the serializer class that handles them.

    After its explicit checks, ``serializer_factory`` walks ``_registry`` in
    insertion order and uses the first entry whose key matches the value via
    ``isinstance``. A key may be a single type or a tuple of types.
    """

    _registry = {
        # Order matters: the first isinstance() match wins, so narrower types
        # must come before broader ones (e.g. datetime.datetime before
        # datetime.date; str/bytes before collections.abc.Iterable).
        frozenset: FrozensetSerializer,
        list: SequenceSerializer,
        set: SetSerializer,
        tuple: TupleSerializer,
        dict: DictionarySerializer,
        models.Choices: ChoicesSerializer,
        enum.Enum: EnumSerializer,
        datetime.datetime: DatetimeDatetimeSerializer,
        (datetime.date, datetime.timedelta, datetime.time): DateTimeSerializer,
        SettingsReference: SettingsReferenceSerializer,
        float: FloatSerializer,
        (bool, int, types.NoneType, bytes, str, range): BaseSimpleSerializer,
        decimal.Decimal: DecimalSerializer,
        (functools.partial, functools.partialmethod): FunctoolsPartialSerializer,
        (
            types.FunctionType,
            types.BuiltinFunctionType,
            types.MethodType,
        ): FunctionTypeSerializer,
        collections.abc.Iterable: IterableSerializer,
        (COMPILED_REGEX_TYPE, RegexObject): RegexSerializer,
        uuid.UUID: UUIDSerializer,
        pathlib.PurePath: PathSerializer,
        os.PathLike: PathLikeSerializer,
    }

    @classmethod
    def register(cls, type_, serializer):
        """Add or replace the serializer class used for ``type_``.

        Raises:
            ValueError: if ``serializer`` does not subclass BaseSerializer.
        """
        if issubclass(serializer, BaseSerializer):
            cls._registry[type_] = serializer
            return
        raise ValueError(
            "'%s' must inherit from 'BaseSerializer'." % serializer.__name__
        )

    @classmethod
    def unregister(cls, type_):
        """Remove ``type_`` from the registry; raises KeyError if absent."""
        del cls._registry[type_]
359
360
def serializer_factory(value):
    """Return a serializer instance that can render ``value`` as source code.

    Lazy values are unwrapped first. Model fields, managers, migration
    operations, classes, and objects with a ``deconstruct`` attribute are
    dispatched explicitly. Everything else is matched against
    ``Serializer._registry`` in order.

    Raises:
        ValueError: if no serializer handles ``value``.
    """
    if isinstance(value, Promise):
        # Resolve the lazy Promise proxy to a concrete str.
        value = str(value)
    elif isinstance(value, LazyObject):
        # The unwrapped value is returned as the first item of the arguments
        # tuple.
        value = value.__reduce__()[1][0]

    explicit_dispatch = (
        (models.Field, ModelFieldSerializer),
        (models.manager.BaseManager, ModelManagerSerializer),
        (Operation, OperationSerializer),
        (type, TypeSerializer),
    )
    for kind, serializer_cls in explicit_dispatch:
        if isinstance(value, kind):
            return serializer_cls(value)
    # Anything that knows how to deconstruct itself.
    if hasattr(value, "deconstruct"):
        return DeconstructableSerializer(value)
    for type_, serializer_cls in Serializer._registry.items():
        if isinstance(value, type_):
            return serializer_cls(value)
    raise ValueError(
        "Cannot serialize: %r\nThere are some values Plain cannot serialize into "
        "migration files." % value
    )