1from __future__ import annotations
  2
  3import json
  4from collections.abc import Callable
  5from typing import TYPE_CHECKING, Any
  6
  7from plain import exceptions, preflight
  8from plain.models import expressions, lookups
  9from plain.models.constants import LOOKUP_SEP
 10from plain.models.fields import TextField
 11from plain.models.lookups import (
 12    FieldGetDbPrepValueMixin,
 13    Lookup,
 14    OperatorLookup,
 15    Transform,
 16)
 17from plain.models.postgres.sql import adapt_json_value
 18
 19from . import Field
 20
 21if TYPE_CHECKING:
 22    from plain.models.postgres.wrapper import DatabaseWrapper
 23    from plain.models.sql.compiler import SQLCompiler
 24    from plain.preflight.results import PreflightResult
 25
# Only JSONField is exported for star-imports; lookups/transforms stay internal.
__all__ = ["JSONField"]
 27
 28
 29class JSONField(Field):
 30    empty_strings_allowed = False
 31    description = "A JSON object"
 32    default_error_messages = {
 33        "invalid": "Value must be valid JSON.",
 34    }
 35    _default_fix = ("dict", "{}")
 36
 37    def __init__(
 38        self,
 39        *,
 40        encoder: type[json.JSONEncoder] | None = None,
 41        decoder: type[json.JSONDecoder] | None = None,
 42        **kwargs: Any,
 43    ):
 44        if encoder and not callable(encoder):
 45            raise ValueError("The encoder parameter must be a callable object.")
 46        if decoder and not callable(decoder):
 47            raise ValueError("The decoder parameter must be a callable object.")
 48        self.encoder = encoder
 49        self.decoder = decoder
 50        super().__init__(**kwargs)
 51
 52    def _check_default(self) -> list[PreflightResult]:
 53        if (
 54            self.has_default()
 55            and self.default is not None
 56            and not callable(self.default)
 57        ):
 58            return [
 59                preflight.PreflightResult(
 60                    fix=(
 61                        f"{self.__class__.__name__} default should be a callable instead of an instance "
 62                        "so that it's not shared between all field instances. "
 63                        "Use a callable instead, e.g., use `{}` instead of "
 64                        "`{}`.".format(*self._default_fix)
 65                    ),
 66                    obj=self,
 67                    id="fields.invalid_choice_mixin_default",
 68                    warning=True,
 69                )
 70            ]
 71        else:
 72            return []
 73
 74    def preflight(self, **kwargs: Any) -> list[PreflightResult]:
 75        errors = super().preflight(**kwargs)
 76        errors.extend(self._check_default())
 77        errors.extend(self._check_supported())
 78        return errors
 79
 80    def _check_supported(self) -> list[PreflightResult]:
 81        # PostgreSQL always supports JSONField (native JSONB type).
 82        return []
 83
 84    def deconstruct(self) -> tuple[str | None, str, list[Any], dict[str, Any]]:
 85        name, path, args, kwargs = super().deconstruct()
 86        if self.encoder is not None:
 87            kwargs["encoder"] = self.encoder
 88        if self.decoder is not None:
 89            kwargs["decoder"] = self.decoder
 90        return name, path, args, kwargs
 91
 92    def from_db_value(
 93        self, value: Any, expression: Any, connection: DatabaseWrapper
 94    ) -> Any:
 95        if value is None:
 96            return value
 97        # KeyTransform may extract non-string values directly.
 98        if isinstance(expression, KeyTransform) and not isinstance(value, str):
 99            return value
100        try:
101            return json.loads(value, cls=self.decoder)
102        except json.JSONDecodeError:
103            return value
104
105    def get_internal_type(self) -> str:
106        return "JSONField"
107
108    def get_db_prep_value(
109        self, value: Any, connection: DatabaseWrapper, prepared: bool = False
110    ) -> Any:
111        if isinstance(value, expressions.Value) and isinstance(
112            value.output_field, JSONField
113        ):
114            value = value.value
115        elif hasattr(value, "as_sql"):
116            return value
117        return adapt_json_value(value, self.encoder)
118
119    def get_db_prep_save(self, value: Any, connection: DatabaseWrapper) -> Any:
120        if value is None:
121            return value
122        return self.get_db_prep_value(value, connection)
123
124    def get_transform(
125        self, lookup_name: str
126    ) -> type[Transform] | Callable[..., Any] | None:
127        # Always returns a transform (never None in practice)
128        transform = super().get_transform(lookup_name)
129        if transform:
130            return transform
131        return KeyTransformFactory(lookup_name)
132
133    def validate(self, value: Any, model_instance: Any) -> None:
134        super().validate(value, model_instance)
135        try:
136            json.dumps(value, cls=self.encoder)
137        except TypeError:
138            raise exceptions.ValidationError(
139                self.error_messages["invalid"],
140                code="invalid",
141                params={"value": value},
142            )
143
144    def value_to_string(self, obj: Any) -> Any:
145        return self.value_from_object(obj)
146
147
class DataContains(FieldGetDbPrepValueMixin, OperatorLookup):
    """``contains``: the column's JSON contains the given JSON (PostgreSQL ``@>``)."""

    operator = "@>"
    lookup_name = "contains"


class ContainedBy(FieldGetDbPrepValueMixin, OperatorLookup):
    """``contained_by``: the column's JSON is contained by the given JSON (PostgreSQL ``<@``)."""

    operator = "<@"
    lookup_name = "contained_by"
158
159
class HasKeyLookup(OperatorLookup):
    """Base class for lookups testing whether a JSON value has key(s).

    A ``KeyTransform`` on the right-hand side is unrolled. All keys of its
    path except the last are pushed onto the left-hand side as nested
    ``KeyTransform`` objects. The last key becomes the plain right-hand value.
    """

    # Set by subclasses; not read anywhere in this module.
    logical_operator: str | None = None

    def as_sql(
        self, compiler: SQLCompiler, connection: DatabaseWrapper
    ) -> tuple[str, tuple[Any, ...]]:
        rhs = self.rhs
        if isinstance(rhs, KeyTransform):
            # Only the key path matters; the compiled SQL/params are discarded.
            *_, key_path = rhs.preprocess_lhs(compiler, connection)
            *parent_keys, last_key = key_path
            for parent_key in parent_keys:
                self.lhs = KeyTransform(parent_key, self.lhs)
            self.rhs = last_key
        return super().as_sql(compiler, connection)
175
176
class HasKey(HasKeyLookup):
    """``has_key``: the key exists in the JSON value (PostgreSQL ``?``)."""

    lookup_name = "has_key"
    operator = "?"
    # Opt out of the base lookup's RHS preparation for the key name.
    prepare_rhs = False


class HasKeys(HasKeyLookup):
    """``has_keys``: every given key exists (PostgreSQL ``?&``)."""

    lookup_name = "has_keys"
    operator = "?&"
    logical_operator = " AND "

    def get_prep_lookup(self) -> list[str]:
        # Keys are matched as text, so coerce each element with str().
        return list(map(str, self.rhs))


class HasAnyKeys(HasKeys):
    """``has_any_keys``: at least one given key exists (PostgreSQL ``?|``)."""

    lookup_name = "has_any_keys"
    operator = "?|"
    logical_operator = " OR "
199
200
class JSONExact(lookups.Exact):
    """``exact`` lookup for JSONField that allows ``None`` as the RHS."""

    can_use_none_as_rhs = True

    def process_rhs(
        self, compiler: SQLCompiler, connection: DatabaseWrapper
    ) -> tuple[str, list[Any]] | tuple[list[str], list[Any]]:
        rhs, rhs_params = super().process_rhs(compiler, connection)
        # A lone placeholder bound to None is compared against the JSON
        # literal 'null' rather than SQL NULL.
        binds_none = isinstance(rhs, str) and rhs == "%s" and rhs_params == [None]
        if binds_none:
            rhs_params = ["null"]
        return rhs, rhs_params


class JSONIContains(lookups.IContains):
    """``icontains`` for JSONField; inherits the base lookup unchanged."""
219
220
# Attach the JSON-specific lookups to JSONField (same order as before).
for _json_lookup in (
    DataContains,
    ContainedBy,
    HasKey,
    HasKeys,
    HasAnyKeys,
    JSONExact,
    JSONIContains,
):
    JSONField.register_lookup(_json_lookup)
del _json_lookup
228
229
class KeyTransform(Transform):
    """Extract a key or array index from a JSON value, keeping the result as JSON.

    A single key compiles to ``(lhs -> %s)``. A chain of nested
    ``KeyTransform`` objects is collapsed into one ``(lhs #> %s)`` with the
    whole key path bound as a single parameter.
    """

    # PostgreSQL operator for one key/index; the result stays JSON.
    operator = "->"
    # PostgreSQL operator for a key path; the result stays JSON.
    nested_operator = "#>"

    def __init__(self, key_name: str, *args: Any, **kwargs: Any):
        super().__init__(*args, **kwargs)
        # Stored as text; as_sql decides whether it is used as an array index.
        self.key_name = str(key_name)

    def preprocess_lhs(
        self, compiler: SQLCompiler, connection: DatabaseWrapper
    ) -> tuple[str, tuple[Any, ...], list[str]]:
        """Walk down nested KeyTransforms and compile the innermost expression.

        Returns:
            ``(sql, params, keys)``. ``sql`` and ``params`` come from compiling
            the first non-KeyTransform expression. ``keys`` lists the key names
            in the order they are applied to it, with this transform's key last.
        """
        node: Any = self
        path: list[str] = []
        while isinstance(node, KeyTransform):
            path.append(node.key_name)
            node = node.lhs
        path.reverse()
        base_sql, base_params = compiler.compile(node)
        return base_sql, base_params, path

    def as_sql(
        self,
        compiler: SQLCompiler,
        connection: DatabaseWrapper,
        function: str | None = None,
        template: str | None = None,
        arg_joiner: str | None = None,
        **extra_context: Any,
    ) -> tuple[str, list[Any]]:
        # function/template/arg_joiner/extra_context are accepted but unused.
        base_sql, base_params, path = self.preprocess_lhs(compiler, connection)
        if len(path) == 1:
            # Integer-like keys are bound as ints so the operator indexes arrays.
            lookup_key: int | str
            try:
                lookup_key = int(self.key_name)
            except ValueError:
                lookup_key = self.key_name
            return f"({base_sql} {self.operator} %s)", [*base_params, lookup_key]
        # Several keys: bind the full path as one parameter.
        return f"({base_sql} {self.nested_operator} %s)", [*base_params, path]
269
270
class KeyTextTransform(KeyTransform):
    """Like ``KeyTransform``, but yields the extracted value as text.

    Uses PostgreSQL's ``->>`` (single key) and ``#>>`` (key path) operators.
    """

    operator = "->>"
    nested_operator = "#>>"
    output_field = TextField()

    @classmethod
    def from_lookup(cls, lookup: str) -> Any:
        """Build a chain of text transforms from a ``LOOKUP_SEP``-joined string.

        The first segment names the base expression; each following segment
        wraps the chain in another transform for that key.

        Raises:
            ValueError: If ``lookup`` contains no key segments.
        """
        base, *keys = lookup.split(LOOKUP_SEP)
        if not keys:
            raise ValueError("Lookup must contain key or index transforms.")
        expression = base
        for key in keys:
            expression = cls(key, expression)
        return expression


# Shorthand for KeyTextTransform.from_lookup.
KT = KeyTextTransform.from_lookup
289
290
class KeyTransformTextLookupMixin(Lookup):
    """Mixin for lookups that need the LHS key value as text.

    The incoming ``KeyTransform`` is rebuilt as a ``KeyTextTransform`` with the
    same key and source expressions, so the SQL uses ``->>`` instead of ``->``.

    Raises:
        TypeError: If the LHS is not a ``KeyTransform``.
    """

    def __init__(self, key_transform: Any, *args: Any, **kwargs: Any):
        if not isinstance(key_transform, KeyTransform):
            raise TypeError(
                "Transform should be an instance of KeyTransform in order to "
                "use this lookup."
            )
        key_name = key_transform.key_name
        source_args = key_transform.source_expressions
        source_kwargs = key_transform.extra
        text_lhs = KeyTextTransform(key_name, *source_args, **source_kwargs)
        super().__init__(text_lhs, *args, **kwargs)
309
310
class KeyTransformIsNull(lookups.IsNull):
    """``isnull`` on a JSON key; ``key__isnull=False`` is the same as ``has_key='key'``."""


class KeyTransformIn(lookups.In):
    """``in`` lookup on a JSON key transform."""

    def resolve_expression_parameter(
        self,
        compiler: SQLCompiler,
        connection: DatabaseWrapper,
        sql: str,
        param: Any,
    ) -> tuple[str, list[Any]]:
        resolved_sql, resolved_params = super().resolve_expression_parameter(
            compiler, connection, sql, param
        )
        # Normalize the parameters to a list.
        return resolved_sql, [*resolved_params]


class KeyTransformExact(JSONExact):
    """``exact`` on a JSON key transform."""

    def process_rhs(
        self, compiler: SQLCompiler, connection: DatabaseWrapper
    ) -> tuple[str, list[Any]] | tuple[list[str], list[Any]]:
        if not isinstance(self.rhs, KeyTransform):
            return super().process_rhs(compiler, connection)
        # RHS is another key transform: bypass JSONExact and Exact by resolving
        # process_rhs from the class after lookups.Exact in the MRO.
        return super(lookups.Exact, self).process_rhs(compiler, connection)
340
341
class KeyTransformIExact(KeyTransformTextLookupMixin, lookups.IExact):
    """``iexact`` against the key's text value."""


class KeyTransformIContains(KeyTransformTextLookupMixin, lookups.IContains):
    """``icontains`` against the key's text value."""


class KeyTransformStartsWith(KeyTransformTextLookupMixin, lookups.StartsWith):
    """``startswith`` against the key's text value."""


class KeyTransformIStartsWith(KeyTransformTextLookupMixin, lookups.IStartsWith):
    """``istartswith`` against the key's text value."""


class KeyTransformEndsWith(KeyTransformTextLookupMixin, lookups.EndsWith):
    """``endswith`` against the key's text value."""


class KeyTransformIEndsWith(KeyTransformTextLookupMixin, lookups.IEndsWith):
    """``iendswith`` against the key's text value."""


class KeyTransformRegex(KeyTransformTextLookupMixin, lookups.Regex):
    """``regex`` against the key's text value."""


class KeyTransformIRegex(KeyTransformTextLookupMixin, lookups.IRegex):
    """``iregex`` against the key's text value."""


class KeyTransformLt(lookups.LessThan):
    """``lt`` on a JSON key transform; base lookup unchanged."""


class KeyTransformLte(lookups.LessThanOrEqual):
    """``lte`` on a JSON key transform; base lookup unchanged."""


class KeyTransformGt(lookups.GreaterThan):
    """``gt`` on a JSON key transform; base lookup unchanged."""


class KeyTransformGte(lookups.GreaterThanOrEqual):
    """``gte`` on a JSON key transform; base lookup unchanged."""
388
389
# Attach the key-level lookups to KeyTransform (same order as before).
for _key_lookup in (
    KeyTransformIn,
    KeyTransformExact,
    KeyTransformIExact,
    KeyTransformIsNull,
    KeyTransformIContains,
    KeyTransformStartsWith,
    KeyTransformIStartsWith,
    KeyTransformEndsWith,
    KeyTransformIEndsWith,
    KeyTransformRegex,
    KeyTransformIRegex,
    KeyTransformLt,
    KeyTransformLte,
    KeyTransformGt,
    KeyTransformGte,
):
    KeyTransform.register_lookup(_key_lookup)
del _key_lookup
406
407
class KeyTransformFactory:
    """Callable returned by ``JSONField.get_transform`` for unregistered names.

    It remembers the key name and, when invoked, builds a ``KeyTransform``
    for that key around the supplied expression arguments.
    """

    def __init__(self, key_name: str):
        self.key_name = key_name

    def __call__(self, *args: Any, **kwargs: Any) -> KeyTransform:
        bound_key = self.key_name
        return KeyTransform(bound_key, *args, **kwargs)