1import builtins
2import collections.abc
3import datetime
4import decimal
5import enum
6import functools
7import math
8import os
9import pathlib
10import re
11import types
12import uuid
13
14from plain import models
15from plain.models.migrations.operations.base import Operation
16from plain.models.migrations.utils import COMPILED_REGEX_TYPE, RegexObject
17from plain.runtime import SettingsReference
18from plain.utils.functional import LazyObject, Promise
19
20
class BaseSerializer:
    """Base class for rendering a Python value as migration source code.

    Subclasses implement ``serialize()``, returning a two-tuple of
    ``(source_string, imports)`` where ``imports`` is a collection of import
    statements the source string depends on.
    """

    def __init__(self, value):
        # The object to render; subclasses read it from serialize().
        self.value = value

    def serialize(self):
        """Return ``(source, imports)`` for ``self.value``; must be overridden."""
        message = "Subclasses of BaseSerializer must implement the serialize() method."
        raise NotImplementedError(message)
29
30
class BaseSequenceSerializer(BaseSerializer):
    """Serialize a container by serializing each of its items in turn.

    Subclasses supply ``_format()``: a ``%``-style template with a single
    ``%s`` slot that receives the comma-separated item sources.
    """

    def _format(self):
        raise NotImplementedError(
            "Subclasses of BaseSequenceSerializer must implement the _format() method."
        )

    def serialize(self):
        collected_imports = set()
        rendered = []
        for element in self.value:
            source, needed = serializer_factory(element).serialize()
            collected_imports.update(needed)
            rendered.append(source)
        # Ask for the template only after iterating, as subclasses may inspect
        # self.value to choose it.
        template = self._format()
        return template % ", ".join(rendered), collected_imports
46
47
class BaseSimpleSerializer(BaseSerializer):
    """Serialize values whose ``repr()`` is already import-free source code."""

    def serialize(self):
        source = repr(self.value)
        return source, set()
51
52
class ChoicesSerializer(BaseSerializer):
    """Serialize a choices member by serializing its underlying ``.value``."""

    def serialize(self):
        raw_value = self.value.value
        delegate = serializer_factory(raw_value)
        return delegate.serialize()
56
57
class DateTimeSerializer(BaseSerializer):
    """For datetime.*, except datetime.datetime."""

    def serialize(self):
        # repr() of date/time/timedelta is module-qualified
        # (e.g. "datetime.date(2020, 1, 1)"), so only the module import is needed.
        needed = {"import datetime"}
        return repr(self.value), needed
63
64
class DatetimeDatetimeSerializer(BaseSerializer):
    """For datetime.datetime."""

    def serialize(self):
        current_tz = self.value.tzinfo
        # Aware datetimes in any other zone are normalized to UTC (and stored
        # back on self.value) before rendering; naive ones are left as-is.
        if current_tz is not None and current_tz != datetime.UTC:
            self.value = self.value.astimezone(datetime.UTC)
        return repr(self.value), {"import datetime"}
73
74
class DecimalSerializer(BaseSerializer):
    """Serialize a Decimal via its repr, e.g. ``Decimal('1.5')``."""

    def serialize(self):
        # repr(Decimal) is unqualified, hence the from-import of the name.
        source = repr(self.value)
        return source, {"from decimal import Decimal"}
78
79
class DeconstructableSerializer(BaseSerializer):
    """Serialize objects exposing ``deconstruct() -> (path, args, kwargs)``."""

    @staticmethod
    def serialize_deconstructed(path, args, kwargs):
        """Render ``path(*args, **kwargs)``, emitting kwargs in sorted order.

        Returns ``(source, imports)`` where imports covers the callable's path
        and every argument.
        """
        name, imports = DeconstructableSerializer._serialize_path(path)
        parts = []
        for positional in args:
            source, needed = serializer_factory(positional).serialize()
            parts.append(source)
            imports.update(needed)
        for keyword, argument in sorted(kwargs.items()):
            source, needed = serializer_factory(argument).serialize()
            imports.update(needed)
            parts.append(f"{keyword}={source}")
        return f"{name}({', '.join(parts)})", imports

    @staticmethod
    def _serialize_path(path):
        """Return ``(name, imports)`` for a dotted import path.

        Anything in ``plain.models`` is shortened to ``models.<Name>`` with a
        ``from plain import models`` import; other paths are used verbatim
        after importing their module.
        """
        module, attr = path.rsplit(".", 1)
        if module != "plain.models":
            return path, {f"import {module}"}
        return f"models.{attr}", {"from plain import models"}

    def serialize(self):
        parts = self.value.deconstruct()
        return self.serialize_deconstructed(*parts)
108
109
class DictionarySerializer(BaseSerializer):
    """Serialize a dict as a ``{key: value, ...}`` literal.

    Items are emitted in sorted order so generated migrations are stable.
    When the keys cannot be ordered against each other (e.g. a mix of ``int``
    and ``str`` keys, which makes ``sorted()`` raise ``TypeError``), items are
    ordered by the ``repr`` of their key instead of failing.
    """

    def serialize(self):
        try:
            items = sorted(self.value.items())
        except TypeError:
            # Mutually unorderable keys: order by repr(key), which is always a
            # str and so always comparable. sorted() is stable, so any repr
            # ties keep the dict's insertion order.
            items = sorted(self.value.items(), key=lambda item: repr(item[0]))
        imports = set()
        strings = []
        for key, val in items:
            k_string, k_imports = serializer_factory(key).serialize()
            v_string, v_imports = serializer_factory(val).serialize()
            imports.update(k_imports)
            imports.update(v_imports)
            strings.append(f"{k_string}: {v_string}")
        return "{" + ", ".join(strings) + "}", imports
121
122
class EnumSerializer(BaseSerializer):
    """Serialize an enum member as ``module.Class['NAME']``.

    ``enum.Flag`` values are emitted as their members OR-ed together
    (``module.Class['A'] | module.Class['B']``). When the members obtained by
    iterating the flag cannot spell the value — none at all (e.g. a zero
    value), or bits left uncovered — the value is emitted as a call on the
    class with its raw value instead, e.g. ``module.Class(0)``. Joining an
    empty member list would otherwise produce an empty, invalid expression.
    """

    def serialize(self):
        enum_class = self.value.__class__
        module = enum_class.__module__
        class_path = f"{module}.{enum_class.__qualname__}"
        imports = {f"import {module}"}
        if issubclass(enum_class, enum.Flag):
            # On Python 3.11+ iterating a flag yields its canonical members.
            members = list(self.value)
            covered = 0
            for item in members:
                covered |= item.value
            if not members or covered != self.value.value:
                return f"{class_path}({self.value.value!r})", imports
        else:
            members = (self.value,)
        return (
            " | ".join(f"{class_path}[{item.name!r}]" for item in members),
            imports,
        )
140
141
class FloatSerializer(BaseSimpleSerializer):
    """Serialize floats, spelling NaN/infinity as ``float("...")`` calls."""

    def serialize(self):
        # Finite floats have a literal repr; nan/inf do not, so they are
        # rebuilt from their string form.
        if not math.isfinite(self.value):
            return f'float("{self.value}")', set()
        return super().serialize()
147
148
class FrozensetSerializer(BaseSequenceSerializer):
    """Serialize a frozenset as ``frozenset([...])``.

    Items are sorted by ``repr`` before rendering. Frozenset iteration order
    follows element hashes, which are randomized for str/bytes between
    interpreter runs, so without sorting the same value could produce
    different migration text on each run.
    """

    def __init__(self, value):
        super().__init__(value)
        # repr() always returns a str, so mixed element types stay sortable.
        self.value = sorted(value, key=repr)

    def _format(self):
        return "frozenset([%s])"
152
153
class FunctionTypeSerializer(BaseSerializer):
    """Serialize a function or method as its importable dotted path."""

    def serialize(self):
        func = self.value
        owner = getattr(func, "__self__", None)
        # Methods bound to a class (e.g. classmethods) are referenced through
        # that class.
        if owner and isinstance(owner, type):
            module = owner.__module__
            return f"{module}.{owner.__name__}.{func.__name__}", {f"import {module}"}
        # Further error checking
        if func.__name__ == "<lambda>":
            raise ValueError("Cannot serialize function: lambda")
        if func.__module__ is None:
            raise ValueError(f"Cannot serialize function {func!r}: No module")

        module_name = func.__module__
        # A "<" in the qualname marks nested scopes such as "<locals>", which
        # cannot be reached through a dotted import path.
        if "<" in func.__qualname__:
            raise ValueError(
                f"Could not find function {func.__name__} in {module_name}.\n"
            )
        return f"{module_name}.{func.__qualname__}", {f"import {module_name}"}
180
181
class FunctoolsPartialSerializer(BaseSerializer):
    """Serialize functools.partial / partialmethod objects.

    Renders ``functools.<Kind>(func, *args, **keywords)`` using the
    serialized form of the wrapped callable and its frozen arguments.
    """

    def serialize(self):
        wrapped = self.value
        rendered = {}
        imports = {"import functools"}
        # Serialize the callable, then positional, then keyword arguments.
        for attr in ("func", "args", "keywords"):
            source, needed = serializer_factory(getattr(wrapped, attr)).serialize()
            rendered[attr] = source
            imports.update(needed)
        kind = wrapped.__class__.__name__
        source = (
            f"functools.{kind}({rendered['func']}, "
            f"*{rendered['args']}, **{rendered['keywords']})"
        )
        return source, imports
196
197
class IterableSerializer(BaseSerializer):
    """Serialize any other iterable as a tuple literal."""

    def serialize(self):
        imports = set()
        rendered = []
        for element in self.value:
            source, needed = serializer_factory(element).serialize()
            imports.update(needed)
            rendered.append(source)
        body = ", ".join(rendered)
        # A single element needs a trailing comma to stay a tuple; an empty
        # iterable must be "()" because "(,)" is invalid syntax.
        if len(rendered) == 1:
            return f"({body},)", imports
        return f"({body})", imports
210
211
class ModelFieldSerializer(DeconstructableSerializer):
    """Serialize a model field from its four-part ``deconstruct()`` result.

    The leading attribute name is discarded; only the path and constructor
    arguments are rendered.
    """

    def serialize(self):
        _attr_name, field_path, field_args, field_kwargs = self.value.deconstruct()
        return self.serialize_deconstructed(field_path, field_args, field_kwargs)
216
217
class OperationSerializer(BaseSerializer):
    """Serialize a nested migration operation with OperationWriter."""

    def serialize(self):
        # Imported at call time rather than module level (presumably to avoid
        # an import cycle with the writer module — unconfirmed).
        from plain.models.migrations.writer import OperationWriter

        writer = OperationWriter(self.value, indentation=0)
        source, imports = writer.serialize()
        # Nested operation: the enclosing OperationWriter._write() supplies the
        # trailing comma, so strip the one(s) emitted here.
        return source.rstrip(","), imports
225
226
class PathLikeSerializer(BaseSerializer):
    """Serialize an ``os.PathLike`` as the str/bytes path it represents."""

    def serialize(self):
        # os.fspath() unwraps to str or bytes, whose repr is a plain literal
        # needing no imports. Return an empty *set* — ``{}`` would be an empty
        # dict — so the imports type matches every other serializer.
        return repr(os.fspath(self.value)), set()
230
231
class PathSerializer(BaseSerializer):
    """Serialize pathlib paths as ``pathlib.<Pure...>Path(...)`` expressions."""

    def serialize(self):
        # Concrete paths are emitted as their Pure counterparts so a migration
        # generated on one platform still loads on another.
        is_concrete = isinstance(self.value, pathlib.Path)
        class_prefix = "Pure" if is_concrete else ""
        return "pathlib." + class_prefix + repr(self.value), {"import pathlib"}
238
239
class RegexSerializer(BaseSerializer):
    """Serialize a compiled regex as an ``re.compile(pattern[, flags])`` call."""

    def serialize(self):
        pattern = self.value.pattern
        regex_pattern, pattern_imports = serializer_factory(pattern).serialize()
        # Turn off default implicit flags (e.g. re.U) because regexes with the
        # same implicit and explicit flags aren't equal. The defaults come from
        # an empty pattern of the *same* type: a str pattern implies re.U but a
        # bytes pattern does not. XOR-ing a bytes pattern's flags against the
        # str default would add re.U, which re.compile() rejects for bytes.
        default_flags = re.compile(pattern[:0]).flags
        flags = self.value.flags ^ default_flags
        regex_flags, flag_imports = serializer_factory(flags).serialize()
        imports = {"import re", *pattern_imports, *flag_imports}
        args = [regex_pattern]
        if flags:
            args.append(regex_flags)
        return "re.compile({})".format(", ".join(args)), imports
254
255
class SequenceSerializer(BaseSequenceSerializer):
    """Serialize a list as a ``[...]`` literal."""

    def _format(self):
        list_template = "[%s]"
        return list_template
259
260
class SetSerializer(BaseSequenceSerializer):
    """Serialize a set as a ``{...}`` literal, or ``set()`` when empty.

    Items are sorted by ``repr`` before rendering. Set iteration order follows
    element hashes, which are randomized for str/bytes between interpreter
    runs, so without sorting the same value could produce different migration
    text on each run.
    """

    def __init__(self, value):
        super().__init__(value)
        # repr() always returns a str, so mixed element types stay sortable.
        self.value = sorted(value, key=repr)

    def _format(self):
        # Serialize as a set literal except when value is empty because {}
        # is an empty dict.
        return "{%s}" if self.value else "set(%s)"
266
267
class SettingsReferenceSerializer(BaseSerializer):
    """Serialize a SettingsReference as ``settings.<setting_name>``."""

    def serialize(self):
        # Emit an attribute access on the settings object, named by the
        # reference's setting_name.
        expression = f"settings.{self.value.setting_name}"
        needed = {"from plain.runtime import settings"}
        return expression, needed
273
274
class TupleSerializer(BaseSequenceSerializer):
    """Serialize a tuple as a parenthesized literal."""

    def _format(self):
        # A one-element tuple needs a trailing comma; an empty tuple must be
        # "()" because "(,)" is invalid syntax.
        if len(self.value) == 1:
            return "(%s,)"
        return "(%s)"
280
281
class TypeSerializer(BaseSerializer):
    """Serialize a class object as a reference to that class."""

    def serialize(self):
        # Types that need a specific spelling/import rather than the generic
        # "module.QualName" form.
        special = (
            (models.Model, "models.Model", ["from plain import models"]),
            (types.NoneType, "types.NoneType", ["import types"]),
        )
        for target, source, imports in special:
            if self.value is target:
                return source, set(imports)
        if hasattr(self.value, "__module__"):
            module = self.value.__module__
            # Builtins are referenced by bare name and need no import.
            if module == builtins.__name__:
                return self.value.__name__, set()
            return f"{module}.{self.value.__qualname__}", {f"import {module}"}
        # NOTE(review): falls through to None, matching prior behavior; callers
        # unpack the result, so this path would fail at the call site.
        return None
297
298
class UUIDSerializer(BaseSerializer):
    """Serialize a UUID as ``uuid.UUID('...')``."""

    def serialize(self):
        # repr(UUID) is unqualified ("UUID('...')"), so prefix the module.
        return "uuid." + repr(self.value), {"import uuid"}
302
303
class Serializer:
    """Registry mapping value types to the serializer class that handles them.

    serializer_factory() walks ``_registry`` in insertion order and uses the
    first entry whose key satisfies ``isinstance(value, key)``, so more
    specific types must be listed before more general ones.
    """

    _registry = {
        # Some of these are order-dependent.
        frozenset: FrozensetSerializer,
        list: SequenceSerializer,
        set: SetSerializer,
        tuple: TupleSerializer,
        dict: DictionarySerializer,
        # Listed before enum.Enum so choices take precedence (presumably
        # models.Choices is an Enum subclass — confirm).
        models.Choices: ChoicesSerializer,
        enum.Enum: EnumSerializer,
        # datetime.datetime subclasses datetime.date, so it must come first.
        datetime.datetime: DatetimeDatetimeSerializer,
        (datetime.date, datetime.timedelta, datetime.time): DateTimeSerializer,
        SettingsReference: SettingsReferenceSerializer,
        float: FloatSerializer,
        (bool, int, types.NoneType, bytes, str, range): BaseSimpleSerializer,
        decimal.Decimal: DecimalSerializer,
        (functools.partial, functools.partialmethod): FunctoolsPartialSerializer,
        (
            types.FunctionType,
            types.BuiltinFunctionType,
            types.MethodType,
        ): FunctionTypeSerializer,
        # Generic fallback for iterables; must stay after the concrete
        # containers and str/bytes above, which are iterable too.
        collections.abc.Iterable: IterableSerializer,
        (COMPILED_REGEX_TYPE, RegexObject): RegexSerializer,
        uuid.UUID: UUIDSerializer,
        pathlib.PurePath: PathSerializer,
        os.PathLike: PathLikeSerializer,
    }

    @classmethod
    def register(cls, type_, serializer):
        """Register ``serializer`` (a BaseSerializer subclass) for ``type_``.

        A new key is appended to the end of ``_registry``; re-registering an
        existing key replaces its serializer but keeps its position.

        NOTE(review): because lookup is first-match in order, a newly
        registered type is only consulted after every earlier entry, so
        registering e.g. an iterable type is shadowed by
        collections.abc.Iterable.

        Raises:
            ValueError: if ``serializer`` does not inherit from BaseSerializer.
        """
        if not issubclass(serializer, BaseSerializer):
            raise ValueError(
                f"'{serializer.__name__}' must inherit from 'BaseSerializer'."
            )
        cls._registry[type_] = serializer
340
341
def serializer_factory(value):
    """Return a serializer instance able to render ``value`` as source code.

    Lazy values are resolved first. Model fields, migration operations,
    classes, and objects with a ``deconstruct()`` method get dedicated
    serializers. Everything else is matched against ``Serializer._registry``
    in order, and the first ``isinstance`` match wins.

    Raises:
        ValueError: if no serializer matches ``value``.
    """
    if isinstance(value, Promise):
        # Promise proxies are forced to their str form.
        value = str(value)
    elif isinstance(value, LazyObject):
        # The unwrapped value is returned as the first item of the arguments
        # tuple.
        value = value.__reduce__()[1][0]

    for kind, serializer_cls in (
        (models.Field, ModelFieldSerializer),
        (Operation, OperationSerializer),
        (type, TypeSerializer),
    ):
        if isinstance(value, kind):
            return serializer_cls(value)
    # Anything that knows how to deconstruct itself.
    if hasattr(value, "deconstruct"):
        return DeconstructableSerializer(value)
    matched = next(
        (cls for key, cls in Serializer._registry.items() if isinstance(value, key)),
        None,
    )
    if matched is not None:
        return matched(value)
    raise ValueError(
        f"Cannot serialize: {value!r}\nThere are some values Plain cannot serialize into "
        "migration files."
    )