v0.146.0
  1from __future__ import annotations
  2
  3import json
  4from collections.abc import Callable, Sequence
  5from typing import TYPE_CHECKING, Any
  6
  7from plain import exceptions
  8from plain.postgres import expressions, lookups
  9from plain.postgres.constants import LOOKUP_SEP
 10from plain.postgres.dialect import adapt_json_value
 11from plain.postgres.fields import TextField
 12from plain.postgres.fields.base import NOT_PROVIDED
 13from plain.postgres.lookups import (
 14    FieldGetDbPrepValueMixin,
 15    Lookup,
 16    OperatorLookup,
 17    Transform,
 18)
 19
 20from .base import DefaultableField
 21
 22if TYPE_CHECKING:
 23    from plain.postgres.connection import DatabaseConnection
 24    from plain.postgres.sql.compiler import SQLCompiler
 25
# Only JSONField is exported by a star-import; KT and the lookup/transform
# classes below are still importable by name.
__all__ = ["JSONField"]
 27
 28
 29class JSONField(DefaultableField):
 30    db_type_sql = "jsonb"
 31    empty_strings_allowed = False
 32
 33    def __init__(
 34        self,
 35        *,
 36        encoder: type[json.JSONEncoder] | None = None,
 37        decoder: type[json.JSONDecoder] | None = None,
 38        required: bool = True,
 39        allow_null: bool = False,
 40        default: Any = NOT_PROVIDED,
 41        validators: Sequence[Callable[..., Any]] = (),
 42    ):
 43        if encoder and not callable(encoder):
 44            raise ValueError("The encoder parameter must be a callable object.")
 45        if decoder and not callable(decoder):
 46            raise ValueError("The decoder parameter must be a callable object.")
 47        self.encoder = encoder
 48        self.decoder = decoder
 49        super().__init__(
 50            required=required,
 51            allow_null=allow_null,
 52            default=default,
 53            validators=validators,
 54        )
 55
 56    def deconstruct(self) -> tuple[str | None, str, list[Any], dict[str, Any]]:
 57        name, path, args, kwargs = super().deconstruct()
 58        if self.encoder is not None:
 59            kwargs["encoder"] = self.encoder
 60        if self.decoder is not None:
 61            kwargs["decoder"] = self.decoder
 62        return name, path, args, kwargs
 63
 64    def from_db_value(
 65        self, value: Any, expression: Any, connection: DatabaseConnection
 66    ) -> Any:
 67        if value is None:
 68            return value
 69        # KeyTransform may extract non-string values directly.
 70        if isinstance(expression, KeyTransform) and not isinstance(value, str):
 71            return value
 72        try:
 73            return json.loads(value, cls=self.decoder)
 74        except json.JSONDecodeError:
 75            return value
 76
 77    def get_db_prep_value(
 78        self, value: Any, connection: DatabaseConnection, prepared: bool = False
 79    ) -> Any:
 80        if isinstance(value, expressions.Value) and isinstance(
 81            value.output_field, JSONField
 82        ):
 83            value = value.value
 84        elif hasattr(value, "as_sql"):
 85            return value
 86        return adapt_json_value(value, self.encoder)
 87
 88    def get_db_prep_save(self, value: Any, connection: DatabaseConnection) -> Any:
 89        if value is None:
 90            return value
 91        return self.get_db_prep_value(value, connection)
 92
 93    def get_transform(
 94        self, lookup_name: str
 95    ) -> type[Transform] | Callable[..., Any] | None:
 96        # Always returns a transform (never None in practice)
 97        transform = super().get_transform(lookup_name)
 98        if transform:
 99            return transform
100        return KeyTransformFactory(lookup_name)
101
102    def validate(self, value: Any, model_instance: Any) -> None:
103        super().validate(value, model_instance)
104        try:
105            json.dumps(value, cls=self.encoder)
106        except TypeError:
107            raise exceptions.ValidationError(
108                "Value must be valid JSON.",
109                code="invalid",
110                params={"value": value},
111            )
112
113
class DataContains(FieldGetDbPrepValueMixin, OperatorLookup):
    """``contains``: the left JSON value contains the right JSON value.

    Rendered with PostgreSQL's ``@>`` operator.
    """

    operator = "@>"
    lookup_name = "contains"


class ContainedBy(FieldGetDbPrepValueMixin, OperatorLookup):
    """``contained_by``: the left JSON value is contained by the right one.

    Rendered with PostgreSQL's ``<@`` operator.
    """

    operator = "<@"
    lookup_name = "contained_by"
124
125
class HasKeyLookup(OperatorLookup):
    """Base for key-existence lookups on a JSON value.

    Subclasses choose the PostgreSQL ``operator`` and may set
    ``logical_operator``.
    """

    logical_operator: str | None = None

    def as_sql(
        self, compiler: SQLCompiler, connection: DatabaseConnection
    ) -> tuple[str, tuple[Any, ...]]:
        """Compile the lookup, first expanding a KeyTransform right-hand side.

        When the RHS is a key path, every key except the last is applied to
        the LHS as a KeyTransform, and only the final key is tested for
        existence. The RHS's base expression is compiled by
        ``preprocess_lhs`` but its SQL is not used.
        """
        if isinstance(self.rhs, KeyTransform):
            _base_sql, _base_params, key_path = self.rhs.preprocess_lhs(
                compiler, connection
            )
            *parent_keys, final_key = key_path
            for parent_key in parent_keys:
                self.lhs = KeyTransform(parent_key, self.lhs)
            self.rhs = final_key
        return super().as_sql(compiler, connection)
141
142
class HasKey(HasKeyLookup):
    """``has_key``: the JSON value has the given key.

    Uses PostgreSQL's ``?`` operator. The right-hand side is not prepared
    (``prepare_rhs = False``).
    """

    prepare_rhs = False
    operator = "?"
    lookup_name = "has_key"
148
149
class HasKeys(HasKeyLookup):
    """``has_keys``: the JSON value has every key in the given iterable.

    Uses PostgreSQL's ``?&`` operator.
    """

    logical_operator = " AND "
    operator = "?&"
    lookup_name = "has_keys"

    def get_prep_lookup(self) -> list[str]:
        """Return the requested keys, each coerced to ``str``."""
        return list(map(str, self.rhs))
158
159
class HasAnyKeys(HasKeys):
    """``has_any_keys``: the JSON value has at least one of the given keys.

    Uses PostgreSQL's ``?|`` operator and inherits key coercion from HasKeys.
    """

    logical_operator = " OR "
    operator = "?|"
    lookup_name = "has_any_keys"
165
166
class JSONExact(lookups.Exact):
    """``exact`` lookup for JSONField that accepts ``None`` as a value.

    ``can_use_none_as_rhs`` lets ``field=None`` reach this lookup. A lone
    ``None`` parameter is then sent as the string ``"null"`` so it is
    compared against JSON null.
    """

    can_use_none_as_rhs = True

    def process_rhs(
        self, compiler: SQLCompiler, connection: DatabaseConnection
    ) -> tuple[str, list[Any]] | tuple[list[str], list[Any]]:
        """Compile the RHS, mapping a bare ``None`` parameter to ``"null"``."""
        rhs, rhs_params = super().process_rhs(compiler, connection)
        # Treat None lookup values as null. A list-form ``rhs`` never equals
        # "%s", so no separate type check is needed before comparing.
        if rhs == "%s" and rhs_params == [None]:
            rhs_params = ["null"]
        return rhs, rhs_params
181
182
class JSONIContains(lookups.IContains):
    """``icontains`` for JSONField; reuses the inherited IContains unchanged."""
185
186
# Lookups available directly on JSONField, registered in this order.
for _json_lookup in (
    DataContains,
    ContainedBy,
    HasKey,
    HasKeys,
    HasAnyKeys,
    JSONExact,
    JSONIContains,
):
    JSONField.register_lookup(_json_lookup)
del _json_lookup
194
195
def _key_lookup_param(key_name: str) -> int | str:
    """Return the operand used for a single-key ``->``/``->>`` extraction.

    Keys written as an optional ``-`` followed by ASCII digits (``"0"``,
    ``"-1"``) are sent as integers, so PostgreSQL treats them as array
    indexes; negative indexes count from the end of the array. Every other
    key is sent as text, i.e. as an object key. A bare ``int()`` call would
    also accept whitespace, ``+``, ``_`` separators and non-ASCII digits, and
    so would wrongly turn keys like ``"1_000"`` or ``" 1"`` into indexes.
    """
    digits = key_name.removeprefix("-")
    if digits.isascii() and digits.isdigit():
        return int(key_name)
    return key_name


class KeyTransform(Transform):
    """Extract a key (or array element) from a JSON value as ``jsonb``.

    Chained key transforms (``field__a__b``) are collapsed into a single
    path extraction using ``nested_operator``.
    """

    # PostgreSQL -> operator extracts JSON object field as JSON.
    operator = "->"
    # PostgreSQL #> operator extracts nested JSON path as JSON.
    nested_operator = "#>"

    def __init__(self, key_name: str, *args: Any, **kwargs: Any):
        super().__init__(*args, **kwargs)
        self.key_name = str(key_name)

    def preprocess_lhs(
        self, compiler: SQLCompiler, connection: DatabaseConnection
    ) -> tuple[str, tuple[Any, ...], list[str]]:
        """Walk nested KeyTransforms down to the underlying expression.

        Returns the compiled SQL and params of the innermost
        non-KeyTransform expression, plus the key names ordered from the
        top-level JSON key down (``["a", "b"]`` for ``field__a__b``).
        """
        key_transforms = [self.key_name]
        previous = self.lhs
        while isinstance(previous, KeyTransform):
            key_transforms.append(previous.key_name)
            previous = previous.lhs
        # Collected innermost-first; flip to top-level-first.
        key_transforms.reverse()
        lhs, params = compiler.compile(previous)
        return lhs, params, key_transforms

    def as_sql(
        self,
        compiler: SQLCompiler,
        connection: DatabaseConnection,
        function: str | None = None,
        template: str | None = None,
        arg_joiner: str | None = None,
        **extra_context: Any,
    ) -> tuple[str, list[Any]]:
        """Render the extraction as ``(lhs <op> %s)``.

        A multi-key path uses ``nested_operator`` with the whole key list as
        one parameter (presumably adapted to a text array by the driver;
        verify). A single key uses ``operator`` with the parameter from
        ``_key_lookup_param``.
        """
        lhs, params, key_transforms = self.preprocess_lhs(compiler, connection)
        if len(key_transforms) > 1:
            sql = f"({lhs} {self.nested_operator} %s)"
            return sql, [*params, key_transforms]
        sql = f"({lhs} {self.operator} %s)"
        return sql, [*params, _key_lookup_param(self.key_name)]
235
236
class KeyTextTransform(KeyTransform):
    """Like KeyTransform, but extracts the JSON value as text."""

    # PostgreSQL ->> / #>> extract a JSON object field / nested path as text.
    operator = "->>"
    nested_operator = "#>>"
    output_field = TextField()

    @classmethod
    def from_lookup(cls, lookup: str) -> Any:
        """Build nested text key transforms from a ``"field__key1__key2"`` string.

        The first segment is the base expression and each later segment
        wraps the previous result, so the last key is the outermost
        transform.

        Raises:
            ValueError: if ``lookup`` contains no key segments.
        """
        segments = lookup.split(LOOKUP_SEP)
        if len(segments) < 2:
            raise ValueError("Lookup must contain key or index transforms.")
        expression = segments[0]
        for key in segments[1:]:
            expression = cls(key, expression)
        return expression
252
253
# Shorthand: KT("field__a__b") builds KeyTextTransform("b", KeyTextTransform("a", "field")).
KT = KeyTextTransform.from_lookup
255
256
class KeyTransformTextLookupMixin(Lookup):
    """Mixin for lookups that compare a JSON key's value as text.

    The incoming KeyTransform is rebuilt as a KeyTextTransform with the same
    key, source expressions and extra data, so the SQL uses ``->>``/``#>>``.
    """

    def __init__(self, key_transform: Any, *args: Any, **kwargs: Any):
        if isinstance(key_transform, KeyTransform):
            text_lhs = KeyTextTransform(
                key_transform.key_name,
                *key_transform.source_expressions,
                **key_transform.extra,
            )
            super().__init__(text_lhs, *args, **kwargs)
            return
        raise TypeError(
            "Transform should be an instance of KeyTransform in order to "
            "use this lookup."
        )
275
276
class KeyTransformIsNull(lookups.IsNull):
    """``isnull`` on a JSON key, using the inherited IsNull SQL unchanged.

    ``key__isnull=False`` is the same as ``has_key='key'``.
    """
280
281
class KeyTransformIn(lookups.In):
    """``in`` lookup on a JSON key; normalizes resolved params to a list."""

    def resolve_expression_parameter(
        self,
        compiler: SQLCompiler,
        connection: DatabaseConnection,
        sql: str,
        param: Any,
    ) -> tuple[str, list[Any]]:
        """Delegate to ``In`` and return its params as a list."""
        resolved_sql, resolved_params = super().resolve_expression_parameter(
            compiler, connection, sql, param
        )
        return resolved_sql, [*resolved_params]
297
298
class KeyTransformExact(JSONExact):
    """``exact`` on a JSON key, with special handling for a key-path RHS."""

    def process_rhs(
        self, compiler: SQLCompiler, connection: DatabaseConnection
    ) -> tuple[str, list[Any]] | tuple[list[str], list[Any]]:
        """Compile the RHS.

        A plain value goes through JSONExact (``None`` becomes JSON null).
        When the RHS is itself a KeyTransform, both JSONExact's and
        ``lookups.Exact``'s ``process_rhs`` are skipped. The implementation
        that follows ``lookups.Exact`` in the MRO is used instead.
        """
        if not isinstance(self.rhs, KeyTransform):
            return super().process_rhs(compiler, connection)
        return super(lookups.Exact, self).process_rhs(compiler, connection)
306
307
class KeyTransformIExact(KeyTransformTextLookupMixin, lookups.IExact):
    """``IExact`` applied to a JSON key's text value."""


class KeyTransformIContains(KeyTransformTextLookupMixin, lookups.IContains):
    """``IContains`` applied to a JSON key's text value."""


class KeyTransformStartsWith(KeyTransformTextLookupMixin, lookups.StartsWith):
    """``StartsWith`` applied to a JSON key's text value."""


class KeyTransformIStartsWith(KeyTransformTextLookupMixin, lookups.IStartsWith):
    """``IStartsWith`` applied to a JSON key's text value."""


class KeyTransformEndsWith(KeyTransformTextLookupMixin, lookups.EndsWith):
    """``EndsWith`` applied to a JSON key's text value."""


class KeyTransformIEndsWith(KeyTransformTextLookupMixin, lookups.IEndsWith):
    """``IEndsWith`` applied to a JSON key's text value."""


class KeyTransformRegex(KeyTransformTextLookupMixin, lookups.Regex):
    """``Regex`` applied to a JSON key's text value."""


class KeyTransformIRegex(KeyTransformTextLookupMixin, lookups.IRegex):
    """``IRegex`` applied to a JSON key's text value."""
338
339
class KeyTransformLt(lookups.LessThan):
    """``LessThan`` registered for JSON key transforms; behavior inherited."""


class KeyTransformLte(lookups.LessThanOrEqual):
    """``LessThanOrEqual`` registered for JSON key transforms; behavior inherited."""


class KeyTransformGt(lookups.GreaterThan):
    """``GreaterThan`` registered for JSON key transforms; behavior inherited."""


class KeyTransformGte(lookups.GreaterThanOrEqual):
    """``GreaterThanOrEqual`` registered for JSON key transforms; behavior inherited."""
354
355
# Lookups available after a JSON key transform (``field__key__<lookup>``),
# registered in this order. Those built on KeyTransformTextLookupMixin compare
# the key's text value; the ordering comparisons are registered last.
for _key_lookup in (
    KeyTransformIn,
    KeyTransformExact,
    KeyTransformIExact,
    KeyTransformIsNull,
    KeyTransformIContains,
    KeyTransformStartsWith,
    KeyTransformIStartsWith,
    KeyTransformEndsWith,
    KeyTransformIEndsWith,
    KeyTransformRegex,
    KeyTransformIRegex,
    KeyTransformLt,
    KeyTransformLte,
    KeyTransformGt,
    KeyTransformGte,
):
    KeyTransform.register_lookup(_key_lookup)
del _key_lookup
372
373
class KeyTransformFactory:
    """Callable that builds a KeyTransform for one fixed key name.

    JSONField.get_transform returns one of these for lookup names that are
    not registered transforms.
    """

    def __init__(self, key_name: str):
        self.key_name = key_name

    def __call__(self, *args: Any, **kwargs: Any) -> KeyTransform:
        """Create ``KeyTransform(key_name, *args, **kwargs)``."""
        key = self.key_name
        return KeyTransform(key, *args, **kwargs)