1from __future__ import annotations
2
3import json
4from collections.abc import Callable, Sequence
5from typing import TYPE_CHECKING, Any
6
7from plain import exceptions
8from plain.postgres import expressions, lookups
9from plain.postgres.constants import LOOKUP_SEP
10from plain.postgres.dialect import adapt_json_value
11from plain.postgres.fields import TextField
12from plain.postgres.fields.base import NOT_PROVIDED
13from plain.postgres.lookups import (
14 FieldGetDbPrepValueMixin,
15 Lookup,
16 OperatorLookup,
17 Transform,
18)
19
20from .base import DefaultableField
21
22if TYPE_CHECKING:
23 from plain.postgres.connection import DatabaseConnection
24 from plain.postgres.sql.compiler import SQLCompiler
25
26__all__ = ["JSONField"]
27
28
class JSONField(DefaultableField):
    """Model field stored in a PostgreSQL ``jsonb`` column.

    Values are adapted for the database with ``adapt_json_value`` (optionally
    using a custom ``encoder``) and decoded on read with ``json.loads``
    (optionally using a custom ``decoder``). Unknown ``field__name`` lookups
    resolve to key access via ``KeyTransformFactory``.
    """

    db_type_sql = "jsonb"
    empty_strings_allowed = False

    def __init__(
        self,
        *,
        encoder: type[json.JSONEncoder] | None = None,
        decoder: type[json.JSONDecoder] | None = None,
        required: bool = True,
        allow_null: bool = False,
        default: Any = NOT_PROVIDED,
        validators: Sequence[Callable[..., Any]] = (),
    ):
        """Configure the field.

        Args:
            encoder: Optional ``json.JSONEncoder`` subclass used when
                serializing values.
            decoder: Optional ``json.JSONDecoder`` subclass used when
                decoding values read from the database.
            required, allow_null, default, validators: Forwarded unchanged
                to ``DefaultableField``.

        Raises:
            ValueError: If ``encoder`` or ``decoder`` is given but is not
                callable.
        """
        if encoder and not callable(encoder):
            raise ValueError("The encoder parameter must be a callable object.")
        if decoder and not callable(decoder):
            raise ValueError("The decoder parameter must be a callable object.")
        self.encoder = encoder
        self.decoder = decoder
        super().__init__(
            required=required,
            allow_null=allow_null,
            default=default,
            validators=validators,
        )

    def deconstruct(self) -> tuple[str | None, str, list[Any], dict[str, Any]]:
        """Extend the base deconstruction with ``encoder``/``decoder``.

        Each is only included when it was explicitly set (not ``None``).
        """
        name, path, args, kwargs = super().deconstruct()
        if self.encoder is not None:
            kwargs["encoder"] = self.encoder
        if self.decoder is not None:
            kwargs["decoder"] = self.decoder
        return name, path, args, kwargs

    def from_db_value(
        self, value: Any, expression: Any, connection: DatabaseConnection
    ) -> Any:
        """Convert a value fetched from the database to a Python object.

        ``None`` is returned unchanged. A non-string value produced by a
        ``KeyTransform`` is returned as-is. Anything else is decoded with
        ``json.loads`` using the configured decoder. Text that is not valid
        JSON is returned unchanged rather than raising.
        """
        if value is None:
            return value
        # KeyTransform may extract non-string values directly.
        if isinstance(expression, KeyTransform) and not isinstance(value, str):
            return value
        try:
            return json.loads(value, cls=self.decoder)
        except json.JSONDecodeError:
            return value

    def get_db_prep_value(
        self, value: Any, connection: DatabaseConnection, prepared: bool = False
    ) -> Any:
        """Adapt ``value`` for use as a query/save parameter.

        A ``Value`` whose ``output_field`` is a ``JSONField`` is unwrapped and
        its payload adapted. Any other SQL expression (anything with
        ``as_sql``) is passed through untouched. Plain Python values are
        adapted with ``adapt_json_value`` and the configured encoder.
        ``prepared`` is not consulted by this implementation.
        """
        if isinstance(value, expressions.Value) and isinstance(
            value.output_field, JSONField
        ):
            value = value.value
        elif hasattr(value, "as_sql"):
            return value
        return adapt_json_value(value, self.encoder)

    def get_db_prep_save(self, value: Any, connection: DatabaseConnection) -> Any:
        """Prepare ``value`` for saving.

        ``None`` is returned as-is instead of being JSON-adapted.
        """
        if value is None:
            return value
        return self.get_db_prep_value(value, connection)

    def get_transform(
        self, lookup_name: str
    ) -> type[Transform] | Callable[..., Any] | None:
        """Return a registered transform, else a key transform for ``lookup_name``.

        Because of the ``KeyTransformFactory`` fallback this never returns
        ``None`` in practice.
        """
        transform = super().get_transform(lookup_name)
        if transform:
            return transform
        return KeyTransformFactory(lookup_name)

    def validate(self, value: Any, model_instance: Any) -> None:
        """Run base validation, then check that ``value`` is JSON-serializable.

        Raises:
            ValidationError: With code ``"invalid"`` when ``json.dumps`` (using
                the configured encoder) rejects the value. This covers an
                unserializable object (``TypeError``) and a circular
                reference (``ValueError``).
        """
        super().validate(value, model_instance)
        try:
            json.dumps(value, cls=self.encoder)
        except (TypeError, ValueError) as err:
            # ValueError covers "Circular reference detected", which would
            # otherwise escape validation as a raw exception.
            raise exceptions.ValidationError(
                "Value must be valid JSON.",
                code="invalid",
                params={"value": value},
            ) from err
112
113
class DataContains(FieldGetDbPrepValueMixin, OperatorLookup):
    """``field__contains=value``: true when the left JSON contains ``value``.

    Compiles to PostgreSQL's ``@>`` containment operator.
    """

    operator = "@>"
    lookup_name = "contains"
118
119
class ContainedBy(FieldGetDbPrepValueMixin, OperatorLookup):
    """``field__contained_by=value``: true when the left JSON is inside ``value``.

    Compiles to PostgreSQL's ``<@`` operator (the reverse of ``@>``).
    """

    operator = "<@"
    lookup_name = "contained_by"
124
125
class HasKeyLookup(OperatorLookup):
    """Common base for the JSON key-existence lookups.

    When the right-hand side is itself a ``KeyTransform`` chain, every key in
    that chain except the last is applied to the left-hand side. The last key
    alone becomes the right-hand side. The operator therefore tests for the
    final key inside the nested value. The chain's own base expression is not
    used.
    """

    # Joiner for multi-key variants; HasKeys/HasAnyKeys set it. Not read here.
    logical_operator: str | None = None

    def as_sql(
        self, compiler: SQLCompiler, connection: DatabaseConnection
    ) -> tuple[str, tuple[Any, ...]]:
        if isinstance(self.rhs, KeyTransform):
            key_path = self.rhs.preprocess_lhs(compiler, connection)[-1]
            leading_keys, final_key = key_path[:-1], key_path[-1]
            for key in leading_keys:
                self.lhs = KeyTransform(key, self.lhs)
            self.rhs = final_key
        return super().as_sql(compiler, connection)
141
142
class HasKey(HasKeyLookup):
    """``field__has_key="k"``: true when the JSON value has key ``k``.

    Compiles to PostgreSQL's ``?`` operator. ``prepare_rhs`` is disabled for
    this lookup (presumably so the key string is not run through the JSON
    field's value preparation — confirm against ``Lookup``).
    """

    prepare_rhs = False
    operator = "?"
    lookup_name = "has_key"
148
149
class HasKeys(HasKeyLookup):
    """``field__has_keys=[...]``: true when every listed key is present.

    Compiles to PostgreSQL's ``?&`` operator.
    """

    logical_operator = " AND "
    operator = "?&"
    lookup_name = "has_keys"

    def get_prep_lookup(self) -> list[str]:
        """Return the requested keys, each coerced to ``str``."""
        return list(map(str, self.rhs))
158
159
class HasAnyKeys(HasKeys):
    """``field__has_any_keys=[...]``: true when at least one listed key exists.

    Reuses ``HasKeys`` key preparation but compiles to PostgreSQL's ``?|``.
    """

    logical_operator = " OR "
    operator = "?|"
    lookup_name = "has_any_keys"
165
166
class JSONExact(lookups.Exact):
    """Exact match for JSON values, where a ``None`` rhs means JSON ``null``.

    ``can_use_none_as_rhs`` allows ``field=None`` here. The ``None`` parameter
    is rewritten to the string ``"null"`` so it is compared as a JSON value.
    """

    can_use_none_as_rhs = True

    def process_rhs(
        self, compiler: SQLCompiler, connection: DatabaseConnection
    ) -> tuple[str, list[Any]] | tuple[list[str], list[Any]]:
        """Process the rhs, substituting JSON ``null`` for a lone ``None`` param.

        The params check accepts both list and tuple containers. Parent
        lookups may return either, and a list-only comparison would silently
        miss the tuple form.
        """
        rhs, rhs_params = super().process_rhs(compiler, connection)
        # Treat None lookup values as null.
        if (
            rhs == "%s"
            and isinstance(rhs_params, (list, tuple))
            and len(rhs_params) == 1
            and rhs_params[0] is None
        ):
            rhs_params = ["null"]
        return rhs, rhs_params
181
182
class JSONIContains(lookups.IContains):
    """``icontains`` registered on JSONField; behavior is that of ``lookups.IContains``."""
185
186
# Attach the JSON-specific lookups to JSONField, in this order.
for _json_lookup in (
    DataContains,
    ContainedBy,
    HasKey,
    HasKeys,
    HasAnyKeys,
    JSONExact,
    JSONIContains,
):
    JSONField.register_lookup(_json_lookup)
del _json_lookup
194
195
class KeyTransform(Transform):
    """Extract one key (or array index) from a JSON value, as JSON.

    A single key compiles to ``(lhs -> %s)``. A chain such as
    ``field__a__b`` is collapsed into one ``(lhs #> %s)`` whose parameter is
    the list of keys, e.g. ``["a", "b"]``.
    """

    # PostgreSQL -> operator extracts JSON object field as JSON.
    operator = "->"
    # PostgreSQL #> operator extracts nested JSON path as JSON.
    nested_operator = "#>"

    def __init__(self, key_name: str, *args: Any, **kwargs: Any):
        super().__init__(*args, **kwargs)
        # Normalized to str so non-str keys (e.g. int indexes) can be passed.
        self.key_name = str(key_name)

    def preprocess_lhs(
        self, compiler: SQLCompiler, connection: DatabaseConnection
    ) -> tuple[str, tuple[Any, ...], list[str]]:
        """Flatten nested KeyTransforms and compile the innermost expression.

        Returns:
            ``(sql, params, keys)``. ``sql`` and ``params`` come from compiling
            the first non-KeyTransform expression found by walking ``lhs``.
            ``keys`` lists the key names from that expression outward, ending
            with this transform's own ``key_name``.
        """
        key_transforms = [self.key_name]
        previous = self.lhs
        while isinstance(previous, KeyTransform):
            key_transforms.insert(0, previous.key_name)
            previous = previous.lhs
        lhs, params = compiler.compile(previous)
        return lhs, params, key_transforms

    @staticmethod
    def _index_or_key(key_name: str) -> int | str:
        """Return ``key_name`` as an ``int`` only if it is a plain decimal index.

        Only ASCII digits with an optional leading ``-`` count. ``int()`` on
        its own also accepts surrounding whitespace, a ``+`` sign, ``_``
        separators and non-ASCII digits. That would turn object keys such as
        ``"1_000"`` or ``" 2"`` into array indexes that cannot match an
        object key.
        """
        digits = key_name[1:] if key_name.startswith("-") else key_name
        if digits.isascii() and digits.isdigit():
            return int(key_name)
        return key_name

    def as_sql(
        self,
        compiler: SQLCompiler,
        connection: DatabaseConnection,
        function: str | None = None,
        template: str | None = None,
        arg_joiner: str | None = None,
        **extra_context: Any,
    ) -> tuple[str, list[Any]]:
        """Compile to ``->`` for a single key or ``#>`` for a key path.

        ``function``, ``template``, ``arg_joiner`` and ``extra_context`` are
        not used by this implementation.
        """
        lhs, params, key_transforms = self.preprocess_lhs(compiler, connection)
        if len(key_transforms) > 1:
            # The whole path travels as a single list parameter.
            sql = f"({lhs} {self.nested_operator} %s)"
            return sql, list(params) + [key_transforms]
        lookup = self._index_or_key(self.key_name)
        return f"({lhs} {self.operator} %s)", list(params) + [lookup]
235
236
class KeyTextTransform(KeyTransform):
    """Like ``KeyTransform``, but yields the extracted value as text.

    Uses ``->>`` for a single key and ``#>>`` for a key path, with a
    ``TextField`` output.
    """

    # PostgreSQL ->> / #>> extract a JSON field / nested path as text.
    operator = "->>"
    nested_operator = "#>>"
    output_field = TextField()

    @classmethod
    def from_lookup(cls, lookup: str) -> Any:
        """Build a chain of this transform from a ``"field__key1__key2"`` string.

        The first segment is used as the base expression. Each following
        segment wraps the previous result in ``cls(segment, ...)``.

        Raises:
            ValueError: If ``lookup`` contains no key/index segment.
        """
        segments = lookup.split(LOOKUP_SEP)
        if len(segments) < 2:
            raise ValueError("Lookup must contain key or index transforms.")
        expression: Any = segments[0]
        for segment in segments[1:]:
            expression = cls(segment, expression)
        return expression


# Shorthand: KT("field__key") builds a KeyTextTransform chain.
KT = KeyTextTransform.from_lookup
255
256
class KeyTransformTextLookupMixin(Lookup):
    """Mixin for lookups that need the JSON key's value as text.

    The incoming ``KeyTransform`` is rebuilt as a ``KeyTextTransform``
    (``->>``) with the same key, source expressions and extra context. The
    lookup then operates on text.
    """

    def __init__(self, key_transform: Any, *args: Any, **kwargs: Any):
        if not isinstance(key_transform, KeyTransform):
            raise TypeError(
                "Transform should be an instance of KeyTransform in order to "
                "use this lookup."
            )
        super().__init__(
            KeyTextTransform(
                key_transform.key_name,
                *key_transform.source_expressions,
                **key_transform.extra,
            ),
            *args,
            **kwargs,
        )
275
276
class KeyTransformIsNull(lookups.IsNull):
    """``lookups.IsNull`` applied to an extracted JSON key.

    key__isnull=False is the same as has_key='key'.
    """
280
281
class KeyTransformIn(lookups.In):
    """``lookups.In`` for JSON keys, with expression params returned as a list."""

    def resolve_expression_parameter(
        self,
        compiler: SQLCompiler,
        connection: DatabaseConnection,
        sql: str,
        param: Any,
    ) -> tuple[str, list[Any]]:
        resolved_sql, resolved_params = super().resolve_expression_parameter(
            compiler, connection, sql, param
        )
        return resolved_sql, [*resolved_params]
297
298
class KeyTransformExact(JSONExact):
    """Exact match on an extracted JSON key."""

    def process_rhs(
        self, compiler: SQLCompiler, connection: DatabaseConnection
    ) -> tuple[str, list[Any]] | tuple[list[str], list[Any]]:
        """Process the rhs, bypassing JSON handling when it is another key.

        Plain values go through ``JSONExact`` (including its ``None`` ->
        ``"null"`` substitution). A ``KeyTransform`` rhs skips ``JSONExact``
        and ``lookups.Exact`` and uses the next ``process_rhs`` in the MRO.
        """
        if not isinstance(self.rhs, KeyTransform):
            return super().process_rhs(compiler, connection)
        return super(lookups.Exact, self).process_rhs(compiler, connection)
306
307
class KeyTransformIExact(KeyTransformTextLookupMixin, lookups.IExact):
    """``lookups.IExact`` applied to the key's value extracted as text (``->>``)."""
310
311
class KeyTransformIContains(KeyTransformTextLookupMixin, lookups.IContains):
    """``lookups.IContains`` applied to the key's value extracted as text (``->>``)."""
314
315
class KeyTransformStartsWith(KeyTransformTextLookupMixin, lookups.StartsWith):
    """``lookups.StartsWith`` applied to the key's value extracted as text (``->>``)."""
318
319
class KeyTransformIStartsWith(KeyTransformTextLookupMixin, lookups.IStartsWith):
    """``lookups.IStartsWith`` applied to the key's value extracted as text (``->>``)."""
322
323
class KeyTransformEndsWith(KeyTransformTextLookupMixin, lookups.EndsWith):
    """``lookups.EndsWith`` applied to the key's value extracted as text (``->>``)."""
326
327
class KeyTransformIEndsWith(KeyTransformTextLookupMixin, lookups.IEndsWith):
    """``lookups.IEndsWith`` applied to the key's value extracted as text (``->>``)."""
330
331
class KeyTransformRegex(KeyTransformTextLookupMixin, lookups.Regex):
    """``lookups.Regex`` applied to the key's value extracted as text (``->>``)."""
334
335
class KeyTransformIRegex(KeyTransformTextLookupMixin, lookups.IRegex):
    """``lookups.IRegex`` applied to the key's value extracted as text (``->>``)."""
338
339
class KeyTransformLt(lookups.LessThan):
    """``lookups.LessThan`` on the key transform's output (no text conversion mixin)."""
342
343
class KeyTransformLte(lookups.LessThanOrEqual):
    """``lookups.LessThanOrEqual`` on the key transform's output (no text conversion mixin)."""
346
347
class KeyTransformGt(lookups.GreaterThan):
    """``lookups.GreaterThan`` on the key transform's output (no text conversion mixin)."""
350
351
class KeyTransformGte(lookups.GreaterThanOrEqual):
    """``lookups.GreaterThanOrEqual`` on the key transform's output (no text conversion mixin)."""
354
355
# Attach key-level lookups to KeyTransform: equality/membership, text-based
# matching, then ordering comparisons (registration order preserved).
for _key_lookup in (
    KeyTransformIn,
    KeyTransformExact,
    KeyTransformIExact,
    KeyTransformIsNull,
    KeyTransformIContains,
    KeyTransformStartsWith,
    KeyTransformIStartsWith,
    KeyTransformEndsWith,
    KeyTransformIEndsWith,
    KeyTransformRegex,
    KeyTransformIRegex,
    KeyTransformLt,
    KeyTransformLte,
    KeyTransformGt,
    KeyTransformGte,
):
    KeyTransform.register_lookup(_key_lookup)
del _key_lookup
372
373
class KeyTransformFactory:
    """Callable that builds a ``KeyTransform`` for one fixed key name.

    ``JSONField.get_transform`` returns one of these for names that are not
    registered transforms.
    """

    def __init__(self, key_name: str):
        self.key_name = key_name

    def __call__(self, *args: Any, **kwargs: Any) -> KeyTransform:
        """Return ``KeyTransform(self.key_name, *args, **kwargs)``."""
        bound_key = self.key_name
        return KeyTransform(bound_key, *args, **kwargs)