1import builtins
2import collections.abc
3import datetime
4import decimal
5import enum
6import functools
7import math
8import os
9import pathlib
10import re
11import types
12import uuid
13
14from plain import models
15from plain.models.migrations.operations.base import Operation
16from plain.models.migrations.utils import COMPILED_REGEX_TYPE, RegexObject
17from plain.runtime import SettingsReference
18from plain.utils.functional import LazyObject, Promise
19
20
class BaseSerializer:
    """Root of all migration value serializers.

    A serializer wraps one value; ``serialize()`` returns a pair of
    ``(source_code, imports)`` describing how to re-create that value in a
    migration file. Subclasses must override ``serialize()``.
    """

    def __init__(self, value):
        # The object to be rendered as Python source.
        self.value = value

    def serialize(self):
        """Return ``(source_code, imports)``; always raises here."""
        message = "Subclasses of BaseSerializer must implement the serialize() method."
        raise NotImplementedError(message)
29
30
class BaseSequenceSerializer(BaseSerializer):
    """Serialize a container by serializing each element in iteration order.

    Subclasses provide ``_format()``, a %-style template with a single
    ``%s`` slot that receives the comma-separated element sources.
    """

    def _format(self):
        raise NotImplementedError(
            "Subclasses of BaseSequenceSerializer must implement the _format() method."
        )

    def serialize(self):
        collected_imports = set()
        parts = []
        for element in self.value:
            element_code, element_imports = serializer_factory(element).serialize()
            collected_imports.update(element_imports)
            parts.append(element_code)
        template = self._format()
        joined = ", ".join(parts)
        return template % joined, collected_imports
46
47
class BaseSimpleSerializer(BaseSerializer):
    """Serialize values whose ``repr()`` is usable as source with no imports."""

    def serialize(self):
        source = repr(self.value)
        return source, set()
51
52
class ChoicesSerializer(BaseSerializer):
    """Serialize a choices member by serializing its underlying ``.value``."""

    def serialize(self):
        underlying = self.value.value
        return serializer_factory(underlying).serialize()
56
57
class DateTimeSerializer(BaseSerializer):
    """For datetime.*, except datetime.datetime."""

    def serialize(self):
        # repr() of date/time/timedelta is a ``datetime.``-qualified call.
        rendered = repr(self.value)
        return rendered, {"import datetime"}
63
64
class DatetimeDatetimeSerializer(BaseSerializer):
    """For datetime.datetime.

    Aware datetimes whose tzinfo is not ``datetime.UTC`` are converted to
    UTC before being rendered with ``repr()``; naive datetimes are rendered
    unchanged.
    """

    def serialize(self):
        value = self.value
        if value.tzinfo is not None and value.tzinfo != datetime.UTC:
            # Normalize a local copy: serialize() should not rewrite the
            # serializer's own state as a side effect.
            value = value.astimezone(datetime.UTC)
        return repr(value), {"import datetime"}
73
74
class DecimalSerializer(BaseSerializer):
    """Serialize ``decimal.Decimal`` via its repr, e.g. ``Decimal('1.5')``."""

    def serialize(self):
        rendered = repr(self.value)
        return rendered, {"from decimal import Decimal"}
78
79
class DeconstructableSerializer(BaseSerializer):
    """Serialize objects whose ``deconstruct()`` yields ``(path, args, kwargs)``."""

    @staticmethod
    def serialize_deconstructed(path, args, kwargs):
        """Render ``path(*args, **kwargs)`` as source plus required imports.

        Keyword arguments are emitted sorted by keyword name.
        """
        name, imports = DeconstructableSerializer._serialize_path(path)
        rendered_args = []
        for positional in args:
            code, extra = serializer_factory(positional).serialize()
            rendered_args.append(code)
            imports.update(extra)
        for keyword, kw_value in sorted(kwargs.items()):
            code, extra = serializer_factory(kw_value).serialize()
            imports.update(extra)
            rendered_args.append(f"{keyword}={code}")
        joined = ", ".join(rendered_args)
        return f"{name}({joined})", imports

    @staticmethod
    def _serialize_path(path):
        """Return ``(name, imports)`` for a dotted path.

        Names in ``plain.models`` are shortened to ``models.<Name>`` with a
        ``from plain import models`` import; any other path is referenced in
        full after ``import <module>``.
        """
        module, name = path.rsplit(".", 1)
        if module != "plain.models":
            return path, {f"import {module}"}
        return f"models.{name}", {"from plain import models"}

    def serialize(self):
        return self.serialize_deconstructed(*self.value.deconstruct())
108
109
class DictionarySerializer(BaseSerializer):
    """Serialize a dict as a ``{key: value, ...}`` literal.

    Items are emitted in sorted key order so output is deterministic. If the
    keys cannot be ordered against each other (for example a mix of ``int``
    and ``str`` keys), the dict's insertion order is used instead of failing.
    """

    def serialize(self):
        imports = set()
        strings = []
        try:
            items = sorted(self.value.items())
        except TypeError:
            # Mutually unorderable keys: insertion order is still
            # deterministic, so use it rather than aborting serialization.
            items = list(self.value.items())
        for k, v in items:
            k_string, k_imports = serializer_factory(k).serialize()
            v_string, v_imports = serializer_factory(v).serialize()
            imports.update(k_imports)
            imports.update(v_imports)
            strings.append((k_string, v_string))
        return "{{{}}}".format(", ".join(f"{k}: {v}" for k, v in strings)), imports
121
122
class EnumSerializer(BaseSerializer):
    """Serialize an enum member as ``module.Class['NAME']``.

    ``enum.Flag`` values are decomposed into their members and joined with
    `` | ``. A flag value that iterates to no members (such as ``Flag(0)``)
    is rendered as a call on the class with its raw value, because an empty
    join would yield an empty, invalid expression.
    """

    def serialize(self):
        enum_class = self.value.__class__
        module = enum_class.__module__
        class_path = f"{module}.{enum_class.__qualname__}"
        imports = {f"import {module}"}
        if issubclass(enum_class, enum.Flag):
            members = list(self.value)
            if not members:
                # e.g. Flag(0): nothing to join, so reconstruct from the value.
                return f"{class_path}({self.value.value!r})", imports
        else:
            members = (self.value,)
        return (
            " | ".join(f"{class_path}[{item.name!r}]" for item in members),
            imports,
        )
140
141
class FloatSerializer(BaseSimpleSerializer):
    """Serialize floats, spelling NaN and infinities as ``float("...")``."""

    def serialize(self):
        number = self.value
        if not (math.isnan(number) or math.isinf(number)):
            return super().serialize()
        # repr() of nan/inf is not a valid Python literal.
        return f'float("{number}")', set()
147
148
class FrozensetSerializer(BaseSequenceSerializer):
    """Render a frozenset as ``frozenset([...])``."""

    def _format(self):
        template = "frozenset([%s])"
        return template
152
153
class FunctionTypeSerializer(BaseSerializer):
    """Serialize a function or method as an importable dotted reference.

    Methods whose ``__self__`` is a class (e.g. classmethods) are rendered
    through the class's qualified name. Lambdas, functions without a
    ``__module__``, and functions whose ``__qualname__`` contains ``<``
    (such as ``<locals>``) raise ``ValueError``.
    """

    def serialize(self):
        owner = getattr(self.value, "__self__", None)
        if owner and isinstance(owner, type):
            module = owner.__module__
            # __qualname__ (not __name__) keeps methods of nested classes
            # importable: ``Outer.Inner.method`` rather than ``Inner.method``.
            return f"{module}.{owner.__qualname__}.{self.value.__name__}", {
                f"import {module}"
            }
        # Further error checking
        if self.value.__name__ == "<lambda>":
            raise ValueError("Cannot serialize function: lambda")
        if self.value.__module__ is None:
            raise ValueError(f"Cannot serialize function {self.value!r}: No module")

        module_name = self.value.__module__

        if "<" not in self.value.__qualname__:  # Qualname can include <locals>
            return f"{module_name}.{self.value.__qualname__}", {
                f"import {self.value.__module__}"
            }

        raise ValueError(
            f"Could not find function {self.value.__name__} in {module_name}.\n"
        )
180
181
class FunctoolsPartialSerializer(BaseSerializer):
    """Serialize ``functools.partial`` / ``partialmethod`` objects.

    Emits ``functools.<Class>(func, *args, **keywords)``, serializing the
    wrapped callable, positional args, and keywords recursively.
    """

    def serialize(self):
        partial_obj = self.value
        func_code, func_imports = serializer_factory(partial_obj.func).serialize()
        args_code, args_imports = serializer_factory(partial_obj.args).serialize()
        kw_code, kw_imports = serializer_factory(partial_obj.keywords).serialize()
        imports = {"import functools"}
        imports.update(func_imports, args_imports, kw_imports)
        kind = partial_obj.__class__.__name__
        return f"functools.{kind}({func_code}, *{args_code}, **{kw_code})", imports
196
197
class IterableSerializer(BaseSerializer):
    """Serialize an arbitrary iterable as a tuple literal."""

    def serialize(self):
        imports = set()
        rendered = []
        for element in self.value:
            code, extra = serializer_factory(element).serialize()
            imports.update(extra)
            rendered.append(code)
        body = ", ".join(rendered)
        # A 1-tuple needs a trailing comma; "()" is correct when empty,
        # whereas "(,)" would be a syntax error.
        if len(rendered) == 1:
            return f"({body},)", imports
        return f"({body})", imports
210
211
class ModelFieldSerializer(DeconstructableSerializer):
    """Serialize a model field from its four-element ``deconstruct()`` result."""

    def serialize(self):
        # The first element (the attribute name) is not part of the call.
        _, path, args, kwargs = self.value.deconstruct()
        return self.serialize_deconstructed(path, args, kwargs)
216
217
class ModelManagerSerializer(DeconstructableSerializer):
    """Serialize a manager.

    When ``deconstruct()`` reports the ``as_manager`` flag, the result is
    ``<queryset path>.as_manager()``; otherwise the manager class is called
    with its deconstructed arguments.
    """

    def serialize(self):
        as_manager, manager_path, qs_path, args, kwargs = self.value.deconstruct()
        if not as_manager:
            return self.serialize_deconstructed(manager_path, args, kwargs)
        qs_name, imports = self._serialize_path(qs_path)
        return f"{qs_name}.as_manager()", imports
226
227
class OperationSerializer(BaseSerializer):
    """Serialize a nested migration operation using ``OperationWriter``."""

    def serialize(self):
        # Imported at call time (presumably to avoid an import cycle with the
        # writer module — confirm).
        from plain.models.migrations.writer import OperationWriter

        writer = OperationWriter(self.value, indentation=0)
        code, imports = writer.serialize()
        # Nested operation, trailing comma is handled in upper OperationWriter._write()
        return code.rstrip(","), imports
235
236
class PathLikeSerializer(BaseSerializer):
    """Serialize an ``os.PathLike`` object as the repr of its ``os.fspath()``."""

    def serialize(self):
        # Return an empty *set* like every other serializer; ``{}`` is an
        # empty dict, which breaks callers that combine imports with set
        # operators (e.g. ``a | b``).
        return repr(os.fspath(self.value)), set()
240
241
class PathSerializer(BaseSerializer):
    """Serialize ``pathlib`` paths as ``pathlib.Pure*Path(...)`` calls."""

    def serialize(self):
        # Concrete paths are emitted as their Pure* counterparts so that a
        # migration generated on one platform still loads on another; pure
        # paths already repr as Pure*Path and are emitted unchanged.
        if isinstance(self.value, pathlib.Path):
            prefix = "Pure"
        else:
            prefix = ""
        return f"pathlib.{prefix}{self.value!r}", {"import pathlib"}
248
249
class RegexSerializer(BaseSerializer):
    """Serialize a compiled regex (or ``RegexObject``) as ``re.compile(...)``.

    Flags that ``re.compile`` sets implicitly for the pattern's type
    (``re.UNICODE`` for ``str`` patterns, nothing for ``bytes``) are
    stripped, so only the remaining flags are passed explicitly.
    """

    def serialize(self):
        pattern = self.value.pattern
        regex_pattern, pattern_imports = serializer_factory(pattern).serialize()
        # The implicit defaults depend on the pattern type, so compile an
        # empty pattern of the same type. Clear those bits with AND-NOT:
        # XOR would *add* re.UNICODE to an re.ASCII str pattern or to any
        # bytes pattern, and re.compile() rejects both combinations.
        implicit_flags = re.compile(pattern[:0]).flags
        flags = self.value.flags & ~implicit_flags
        regex_flags, flag_imports = serializer_factory(flags).serialize()
        imports = {"import re", *pattern_imports, *flag_imports}
        args = [regex_pattern]
        if flags:
            args.append(regex_flags)
        return "re.compile({})".format(", ".join(args)), imports
264
265
class SequenceSerializer(BaseSequenceSerializer):
    """Render a list as a ``[...]`` literal."""

    def _format(self):
        list_template = "[%s]"
        return list_template
269
270
class SetSerializer(BaseSequenceSerializer):
    """Render a set as a ``{...}`` literal, or ``set()`` when empty."""

    def _format(self):
        # "{}" would be an empty dict, so an empty set uses the constructor.
        if not self.value:
            return "set(%s)"
        return "{%s}"
276
277
class SettingsReferenceSerializer(BaseSerializer):
    """Serialize a ``SettingsReference`` as ``settings.<setting_name>``."""

    def serialize(self):
        setting = self.value.setting_name
        return f"settings.{setting}", {"from plain.runtime import settings"}
283
284
class TupleSerializer(BaseSequenceSerializer):
    """Render a tuple literal, adding a trailing comma for one element."""

    def _format(self):
        # "(x,)" is required for a 1-tuple; "()" is right for an empty one,
        # whereas "(,)" would be invalid syntax.
        if len(self.value) == 1:
            return "(%s,)"
        return "(%s)"
290
291
class TypeSerializer(BaseSerializer):
    """Serialize a class object as a reference to it.

    ``models.Model`` and ``types.NoneType`` have fixed spellings; builtins
    are referenced by bare name; anything else by ``module.QualName``.
    """

    def serialize(self):
        if self.value is models.Model:
            return "models.Model", {"from plain import models"}
        if self.value is types.NoneType:
            return "types.NoneType", {"import types"}
        if not hasattr(self.value, "__module__"):
            # No reference can be built; nothing is returned.
            return None
        module = self.value.__module__
        if module == builtins.__name__:
            return self.value.__name__, set()
        return f"{module}.{self.value.__qualname__}", {f"import {module}"}
307
308
class UUIDSerializer(BaseSerializer):
    """Serialize a UUID as ``uuid.UUID('...')``."""

    def serialize(self):
        return "uuid." + repr(self.value), {"import uuid"}
312
313
class Serializer:
    """Registry mapping value types to serializer classes.

    ``serializer_factory`` scans ``_registry`` in insertion order and uses the
    first key that matches via ``isinstance``. Order therefore matters: for
    example ``datetime.datetime`` (a ``date`` subclass) precedes the
    ``datetime.date`` entry, ``enum.Enum`` precedes ``int``, and ``str`` /
    ``bytes`` precede ``collections.abc.Iterable``. A new key added through
    ``register()`` is inserted last, after all of the entries below.
    """

    _registry = {
        # Some of these are order-dependent.
        frozenset: FrozensetSerializer,
        list: SequenceSerializer,
        set: SetSerializer,
        tuple: TupleSerializer,
        dict: DictionarySerializer,
        models.Choices: ChoicesSerializer,
        enum.Enum: EnumSerializer,
        datetime.datetime: DatetimeDatetimeSerializer,
        (datetime.date, datetime.timedelta, datetime.time): DateTimeSerializer,
        SettingsReference: SettingsReferenceSerializer,
        float: FloatSerializer,
        (bool, int, types.NoneType, bytes, str, range): BaseSimpleSerializer,
        decimal.Decimal: DecimalSerializer,
        (functools.partial, functools.partialmethod): FunctoolsPartialSerializer,
        (
            types.FunctionType,
            types.BuiltinFunctionType,
            types.MethodType,
        ): FunctionTypeSerializer,
        collections.abc.Iterable: IterableSerializer,
        (COMPILED_REGEX_TYPE, RegexObject): RegexSerializer,
        uuid.UUID: UUIDSerializer,
        pathlib.PurePath: PathSerializer,
        os.PathLike: PathLikeSerializer,
    }

    @classmethod
    def register(cls, type_, serializer):
        """Map ``type_`` to ``serializer``.

        Raises ``ValueError`` when ``serializer`` is not a subclass of
        ``BaseSerializer``.
        """
        if issubclass(serializer, BaseSerializer):
            cls._registry[type_] = serializer
            return
        raise ValueError(
            f"'{serializer.__name__}' must inherit from 'BaseSerializer'."
        )
350
351
352def serializer_factory(value):
353 if isinstance(value, Promise):
354 value = str(value)
355 elif isinstance(value, LazyObject):
356 # The unwrapped value is returned as the first item of the arguments
357 # tuple.
358 value = value.__reduce__()[1][0]
359
360 if isinstance(value, models.Field):
361 return ModelFieldSerializer(value)
362 if isinstance(value, models.manager.BaseManager):
363 return ModelManagerSerializer(value)
364 if isinstance(value, Operation):
365 return OperationSerializer(value)
366 if isinstance(value, type):
367 return TypeSerializer(value)
368 # Anything that knows how to deconstruct itself.
369 if hasattr(value, "deconstruct"):
370 return DeconstructableSerializer(value)
371 for type_, serializer_cls in Serializer._registry.items():
372 if isinstance(value, type_):
373 return serializer_cls(value)
374 raise ValueError(
375 f"Cannot serialize: {value!r}\nThere are some values Plain cannot serialize into "
376 "migration files."
377 )