1from __future__ import annotations
2
3import json
4from collections.abc import Callable
5from typing import TYPE_CHECKING, Any
6
7from plain import exceptions, preflight
8from plain.models import expressions, lookups
9from plain.models.constants import LOOKUP_SEP
10from plain.models.fields import TextField
11from plain.models.lookups import (
12 FieldGetDbPrepValueMixin,
13 Lookup,
14 OperatorLookup,
15 Transform,
16)
17from plain.models.postgres.sql import adapt_json_value
18
19from . import Field
20
21if TYPE_CHECKING:
22 from plain.models.postgres.wrapper import DatabaseWrapper
23 from plain.models.sql.compiler import SQLCompiler
24 from plain.preflight.results import PreflightResult
25
26__all__ = ["JSONField"]
27
28
class JSONField(Field):
    """Model field holding JSON-serializable values.

    The optional ``encoder`` and ``decoder`` classes are forwarded to
    ``json.dumps`` / ``adapt_json_value`` and ``json.loads`` respectively.
    Lookup names that are not registered transforms are treated as JSON
    key/index transforms (see ``get_transform``).
    """

    empty_strings_allowed = False
    description = "A JSON object"
    default_error_messages = {
        "invalid": "Value must be valid JSON.",
    }
    # (suggested callable, literal it replaces) interpolated into the
    # default-value preflight warning.
    _default_fix = ("dict", "{}")

    def __init__(
        self,
        *,
        encoder: type[json.JSONEncoder] | None = None,
        decoder: type[json.JSONDecoder] | None = None,
        **kwargs: Any,
    ):
        """Store the optional encoder/decoder and initialise the base field.

        Raises:
            ValueError: If a truthy ``encoder`` or ``decoder`` is not callable.
        """
        if encoder and not callable(encoder):
            raise ValueError("The encoder parameter must be a callable object.")
        if decoder and not callable(decoder):
            raise ValueError("The decoder parameter must be a callable object.")
        self.encoder = encoder
        self.decoder = decoder
        super().__init__(**kwargs)

    def _check_default(self) -> list[PreflightResult]:
        """Return a warning when the default is a non-callable, non-None value."""
        if not self.has_default() or self.default is None or callable(self.default):
            return []

        # Format only the template part so the interpolated class name can
        # never be re-interpreted as a format field.
        example = "Use a callable instead, e.g., use `{}` instead of `{}`.".format(
            *self._default_fix
        )
        return [
            preflight.PreflightResult(
                fix=(
                    f"{self.__class__.__name__} default should be a callable instead of an instance "
                    "so that it's not shared between all field instances. "
                    f"{example}"
                ),
                obj=self,
                id="fields.invalid_choice_mixin_default",
                warning=True,
            )
        ]

    def preflight(self, **kwargs: Any) -> list[PreflightResult]:
        """Run the base preflight checks plus the JSON-specific ones."""
        errors = super().preflight(**kwargs)
        errors.extend(self._check_default())
        errors.extend(self._check_supported())
        return errors

    def _check_supported(self) -> list[PreflightResult]:
        """Return database-support problems; there are none to report."""
        # PostgreSQL always supports JSONField (native JSONB type).
        return []

    def deconstruct(self) -> tuple[str | None, str, list[Any], dict[str, Any]]:
        """Extend the base deconstruction with any configured encoder/decoder."""
        name, path, args, kwargs = super().deconstruct()
        if self.encoder is not None:
            kwargs["encoder"] = self.encoder
        if self.decoder is not None:
            kwargs["decoder"] = self.decoder
        return name, path, args, kwargs

    def from_db_value(
        self, value: Any, expression: Any, connection: DatabaseWrapper
    ) -> Any:
        """Decode a database value with ``json.loads`` and ``self.decoder``.

        ``None`` is returned unchanged, and so is any value that fails to
        parse as JSON.
        """
        if value is None:
            return value
        # KeyTransform may extract non-string values directly.
        if isinstance(expression, KeyTransform) and not isinstance(value, str):
            return value
        try:
            return json.loads(value, cls=self.decoder)
        except json.JSONDecodeError:
            return value

    def get_internal_type(self) -> str:
        """Return the internal type name of this field."""
        return "JSONField"

    def get_db_prep_value(
        self, value: Any, connection: DatabaseWrapper, prepared: bool = False
    ) -> Any:
        """Adapt ``value`` for the database via ``adapt_json_value``.

        A ``Value`` expression whose output field is a ``JSONField`` is
        unwrapped first; any other object exposing ``as_sql`` is returned
        untouched.
        """
        if isinstance(value, expressions.Value) and isinstance(
            value.output_field, JSONField
        ):
            value = value.value
        elif hasattr(value, "as_sql"):
            return value
        return adapt_json_value(value, self.encoder)

    def get_db_prep_save(self, value: Any, connection: DatabaseWrapper) -> Any:
        """Prepare ``value`` for saving; ``None`` is returned without adapting."""
        if value is None:
            return value
        return self.get_db_prep_value(value, connection)

    def get_transform(
        self, lookup_name: str
    ) -> type[Transform] | Callable[..., Any] | None:
        """Return a registered transform, else a key-transform factory."""
        # Always returns a transform (never None in practice)
        transform = super().get_transform(lookup_name)
        if transform:
            return transform
        return KeyTransformFactory(lookup_name)

    def validate(self, value: Any, model_instance: Any) -> None:
        """Validate that ``value`` can be serialised with ``json.dumps``.

        Raises:
            ValidationError: With code ``"invalid"`` when serialisation fails.
        """
        super().validate(value, model_instance)
        try:
            json.dumps(value, cls=self.encoder)
        except (TypeError, ValueError) as err:
            # json.dumps raises TypeError for unserialisable objects and
            # ValueError for circular references (or out-of-range floats when
            # the encoder disallows them); both mean "not valid JSON".
            raise exceptions.ValidationError(
                self.error_messages["invalid"],
                code="invalid",
                params={"value": value},
            ) from err

    def value_to_string(self, obj: Any) -> Any:
        """Return the field's value on ``obj`` as given by ``value_from_object``."""
        return self.value_from_object(obj)
146
147
class DataContains(FieldGetDbPrepValueMixin, OperatorLookup):
    """``contains`` lookup: the left JSON value contains the right one."""

    # PostgreSQL's @> operator tests whether the left JSON contains the right JSON.
    operator = "@>"
    lookup_name = "contains"
152
153
class ContainedBy(FieldGetDbPrepValueMixin, OperatorLookup):
    """``contained_by`` lookup: the left JSON value is contained by the right one."""

    # PostgreSQL's <@ operator tests whether the left JSON is contained by the right JSON.
    operator = "<@"
    lookup_name = "contained_by"
158
159
class HasKeyLookup(OperatorLookup):
    """Base class for lookups that test for keys in a JSON field."""

    logical_operator: str | None = None

    def as_sql(
        self, compiler: SQLCompiler, connection: DatabaseWrapper
    ) -> tuple[str, tuple[Any, ...]]:
        """Compile the lookup, first flattening a ``KeyTransform`` right-hand side.

        When the RHS is a ``KeyTransform``, all but its last key are moved onto
        the LHS as nested ``KeyTransform`` wrappers and the last key becomes
        the new RHS. Note that this rewrites ``self.lhs`` and ``self.rhs``.
        """
        rhs_expression = self.rhs
        if isinstance(rhs_expression, KeyTransform):
            # Only the key path (third element) of the preprocessing result is needed.
            key_path = rhs_expression.preprocess_lhs(compiler, connection)[-1]
            for parent_key in key_path[:-1]:
                self.lhs = KeyTransform(parent_key, self.lhs)
            self.rhs = key_path[-1]
        return super().as_sql(compiler, connection)
175
176
class HasKey(HasKeyLookup):
    """``has_key`` lookup: the JSON value has the given key."""

    # PostgreSQL's ? operator tests whether a key exists.
    operator = "?"
    lookup_name = "has_key"
    prepare_rhs = False
182
183
class HasKeys(HasKeyLookup):
    """``has_keys`` lookup: the JSON value has every one of the given keys."""

    # PostgreSQL's ?& operator tests whether all keys exist.
    operator = "?&"
    lookup_name = "has_keys"
    logical_operator = " AND "

    def get_prep_lookup(self) -> list[str]:
        """Return the right-hand-side items, each converted with ``str``."""
        return list(map(str, self.rhs))
192
193
class HasAnyKeys(HasKeys):
    """``has_any_keys`` lookup: the JSON value has at least one of the given keys."""

    # PostgreSQL's ?| operator tests whether any key exists.
    operator = "?|"
    lookup_name = "has_any_keys"
    logical_operator = " OR "
199
200
class JSONExact(lookups.Exact):
    """Exact-match lookup for JSON values that accepts ``None`` as the RHS."""

    can_use_none_as_rhs = True

    def process_rhs(
        self, compiler: SQLCompiler, connection: DatabaseWrapper
    ) -> tuple[str, list[Any]] | tuple[list[str], list[Any]]:
        """Process the RHS, replacing a lone ``None`` parameter with ``"null"``."""
        rhs_sql, bound_params = super().process_rhs(compiler, connection)
        # Treat None lookup values as null.
        is_none_placeholder = (
            isinstance(rhs_sql, str) and rhs_sql == "%s" and bound_params == [None]
        )
        if is_none_placeholder:
            bound_params = ["null"]
        return rhs_sql, bound_params
215
216
class JSONIContains(lookups.IContains):
    """``IContains`` lookup for ``JSONField``; adds no behavior of its own."""
219
220
# Register the JSON-specific lookups on JSONField, in a fixed order.
for _json_lookup in (
    DataContains,
    ContainedBy,
    HasKey,
    HasKeys,
    HasAnyKeys,
    JSONExact,
    JSONIContains,
):
    JSONField.register_lookup(_json_lookup)
del _json_lookup  # keep the loop variable out of the module namespace
228
229
class KeyTransform(Transform):
    """Transform that extracts a key or index from a JSON value.

    Chained instances are collapsed into one path expression when compiled.
    """

    # PostgreSQL's -> operator extracts a JSON object field as JSON.
    operator = "->"
    # PostgreSQL's #> operator extracts a nested JSON path as JSON.
    nested_operator = "#>"

    def __init__(self, key_name: str, *args: Any, **kwargs: Any):
        """Initialise the transform and store ``key_name`` as a string."""
        super().__init__(*args, **kwargs)
        self.key_name = str(key_name)

    def preprocess_lhs(
        self, compiler: SQLCompiler, connection: DatabaseWrapper
    ) -> tuple[str, tuple[Any, ...], list[str]]:
        """Walk the chain of nested ``KeyTransform`` objects on the LHS.

        Returns:
            The compiled SQL and params of the innermost non-``KeyTransform``
            expression, plus the key names ordered outermost-first.
        """
        innermost_first = [self.key_name]
        node = self.lhs
        while isinstance(node, KeyTransform):
            innermost_first.append(node.key_name)
            node = node.lhs
        base_sql, base_params = compiler.compile(node)
        # Keys were gathered from this transform inwards; flip to path order.
        return base_sql, base_params, innermost_first[::-1]

    def as_sql(
        self,
        compiler: SQLCompiler,
        connection: DatabaseWrapper,
        function: str | None = None,
        template: str | None = None,
        arg_joiner: str | None = None,
        **extra_context: Any,
    ) -> tuple[str, list[Any]]:
        """Compile to ``(lhs <operator> %s)`` with the key or path as a parameter.

        A path of more than one key uses ``nested_operator`` and passes the
        whole key list as one parameter. A single key uses ``operator`` and
        is passed as an ``int`` when ``int()`` accepts it, else as a string.
        """
        base_sql, base_params, key_path = self.preprocess_lhs(compiler, connection)
        bound = list(base_params)
        if len(key_path) > 1:
            return f"({base_sql} {self.nested_operator} %s)", bound + [key_path]
        try:
            key_or_index = int(self.key_name)
        except ValueError:
            key_or_index = self.key_name
        return f"({base_sql} {self.operator} %s)", bound + [key_or_index]
269
270
class KeyTextTransform(KeyTransform):
    """``KeyTransform`` variant that extracts the value as text."""

    # PostgreSQL's ->> operator extracts a JSON object field as text.
    operator = "->>"
    # PostgreSQL's #>> operator extracts a nested JSON path as text.
    nested_operator = "#>>"
    output_field = TextField()

    @classmethod
    def from_lookup(cls, lookup: str) -> Any:
        """Build nested transforms from a ``LOOKUP_SEP``-separated lookup string.

        The first segment is the innermost expression; each following segment
        wraps the result in another instance of ``cls``.

        Raises:
            ValueError: If ``lookup`` has no segment after the first.
        """
        segments = lookup.split(LOOKUP_SEP)
        if len(segments) < 2:
            raise ValueError("Lookup must contain key or index transforms.")
        expression = segments[0]
        for key_name in segments[1:]:
            expression = cls(key_name, expression)
        return expression


# Short alias for building text key transforms from a lookup string.
KT = KeyTextTransform.from_lookup
289
290
class KeyTransformTextLookupMixin(Lookup):
    """
    Mixin for lookups that need a text LHS from a JSONField key lookup.
    The given KeyTransform is replaced by a KeyTextTransform, which uses the
    ->> operator to extract JSON values as text.
    """

    def __init__(self, key_transform: Any, *args: Any, **kwargs: Any):
        """Swap ``key_transform`` for an equivalent ``KeyTextTransform``.

        Raises:
            TypeError: If ``key_transform`` is not a ``KeyTransform``.
        """
        if not isinstance(key_transform, KeyTransform):
            raise TypeError(
                "Transform should be an instance of KeyTransform "
                "in order to use this lookup."
            )
        sources = key_transform.source_expressions
        extra_options = key_transform.extra
        text_lhs = KeyTextTransform(key_transform.key_name, *sources, **extra_options)
        super().__init__(text_lhs, *args, **kwargs)
309
310
class KeyTransformIsNull(lookups.IsNull):
    """``IsNull`` lookup for JSON key transforms.

    key__isnull=False is the same as has_key='key'.
    """
314
315
class KeyTransformIn(lookups.In):
    """``In`` lookup for JSON key transforms."""

    def resolve_expression_parameter(
        self,
        compiler: SQLCompiler,
        connection: DatabaseWrapper,
        sql: str,
        param: Any,
    ) -> tuple[str, list[Any]]:
        """Delegate to the parent and return its params as a ``list``."""
        resolved_sql, resolved_params = super().resolve_expression_parameter(
            compiler, connection, sql, param
        )
        return resolved_sql, [*resolved_params]
331
332
class KeyTransformExact(JSONExact):
    """Exact-match lookup for JSON key transforms."""

    def process_rhs(
        self, compiler: SQLCompiler, connection: DatabaseWrapper
    ) -> tuple[str, list[Any]] | tuple[list[str], list[Any]]:
        """Process the RHS, treating a ``KeyTransform`` RHS specially."""
        if not isinstance(self.rhs, KeyTransform):
            return super().process_rhs(compiler, connection)
        # Start method resolution after lookups.Exact, so neither JSONExact's
        # nor Exact's own process_rhs runs for a KeyTransform RHS.
        return super(lookups.Exact, self).process_rhs(compiler, connection)
340
341
class KeyTransformIExact(KeyTransformTextLookupMixin, lookups.IExact):
    """``IExact`` lookup applied to a JSON key extracted as text."""
344
345
class KeyTransformIContains(KeyTransformTextLookupMixin, lookups.IContains):
    """``IContains`` lookup applied to a JSON key extracted as text."""
348
349
class KeyTransformStartsWith(KeyTransformTextLookupMixin, lookups.StartsWith):
    """``StartsWith`` lookup applied to a JSON key extracted as text."""
352
353
class KeyTransformIStartsWith(KeyTransformTextLookupMixin, lookups.IStartsWith):
    """``IStartsWith`` lookup applied to a JSON key extracted as text."""
356
357
class KeyTransformEndsWith(KeyTransformTextLookupMixin, lookups.EndsWith):
    """``EndsWith`` lookup applied to a JSON key extracted as text."""
360
361
class KeyTransformIEndsWith(KeyTransformTextLookupMixin, lookups.IEndsWith):
    """``IEndsWith`` lookup applied to a JSON key extracted as text."""
364
365
class KeyTransformRegex(KeyTransformTextLookupMixin, lookups.Regex):
    """``Regex`` lookup applied to a JSON key extracted as text."""
368
369
class KeyTransformIRegex(KeyTransformTextLookupMixin, lookups.IRegex):
    """``IRegex`` lookup applied to a JSON key extracted as text."""
372
373
class KeyTransformLt(lookups.LessThan):
    """``LessThan`` lookup for JSON key transforms; adds no behavior of its own."""
376
377
class KeyTransformLte(lookups.LessThanOrEqual):
    """``LessThanOrEqual`` lookup for JSON key transforms; adds no behavior of its own."""
380
381
class KeyTransformGt(lookups.GreaterThan):
    """``GreaterThan`` lookup for JSON key transforms; adds no behavior of its own."""
384
385
class KeyTransformGte(lookups.GreaterThanOrEqual):
    """``GreaterThanOrEqual`` lookup for JSON key transforms; adds no behavior of its own."""
388
389
# Register the key-transform lookups on KeyTransform, in a fixed order.
for _key_lookup in (
    KeyTransformIn,
    KeyTransformExact,
    KeyTransformIExact,
    KeyTransformIsNull,
    KeyTransformIContains,
    KeyTransformStartsWith,
    KeyTransformIStartsWith,
    KeyTransformEndsWith,
    KeyTransformIEndsWith,
    KeyTransformRegex,
    KeyTransformIRegex,
    # Ordering comparisons.
    KeyTransformLt,
    KeyTransformLte,
    KeyTransformGt,
    KeyTransformGte,
):
    KeyTransform.register_lookup(_key_lookup)
del _key_lookup  # keep the loop variable out of the module namespace
406
407
class KeyTransformFactory:
    """Callable that creates ``KeyTransform`` instances for a fixed key name."""

    def __init__(self, key_name: str):
        """Remember the key name to pass to every ``KeyTransform`` created."""
        self.key_name = key_name

    def __call__(self, *args: Any, **kwargs: Any) -> KeyTransform:
        """Create a ``KeyTransform`` for the stored key, forwarding all arguments."""
        bound_key = self.key_name
        return KeyTransform(bound_key, *args, **kwargs)