Plain is headed towards 1.0! Subscribe for development updates →

  1import json
  2import warnings
  3
  4from plain import exceptions, preflight
  5from plain.models import expressions, lookups
  6from plain.models.constants import LOOKUP_SEP
  7from plain.models.db import NotSupportedError, connections, router
  8from plain.models.fields import TextField
  9from plain.models.lookups import (
 10    FieldGetDbPrepValueMixin,
 11    PostgresOperatorLookup,
 12    Transform,
 13)
 14from plain.utils.deprecation import RemovedInDjango51Warning
 15
 16from . import Field
 17from .mixins import CheckFieldDefaultMixin
 18
 19__all__ = ["JSONField"]
 20
 21
 22class JSONField(CheckFieldDefaultMixin, Field):
 23    empty_strings_allowed = False
 24    description = "A JSON object"
 25    default_error_messages = {
 26        "invalid": "Value must be valid JSON.",
 27    }
 28    _default_hint = ("dict", "{}")
 29
 30    def __init__(
 31        self,
 32        name=None,
 33        encoder=None,
 34        decoder=None,
 35        **kwargs,
 36    ):
 37        if encoder and not callable(encoder):
 38            raise ValueError("The encoder parameter must be a callable object.")
 39        if decoder and not callable(decoder):
 40            raise ValueError("The decoder parameter must be a callable object.")
 41        self.encoder = encoder
 42        self.decoder = decoder
 43        super().__init__(name, **kwargs)
 44
 45    def check(self, **kwargs):
 46        errors = super().check(**kwargs)
 47        databases = kwargs.get("databases") or []
 48        errors.extend(self._check_supported(databases))
 49        return errors
 50
 51    def _check_supported(self, databases):
 52        errors = []
 53        for db in databases:
 54            if not router.allow_migrate_model(db, self.model):
 55                continue
 56            connection = connections[db]
 57            if (
 58                self.model._meta.required_db_vendor
 59                and self.model._meta.required_db_vendor != connection.vendor
 60            ):
 61                continue
 62            if not (
 63                "supports_json_field" in self.model._meta.required_db_features
 64                or connection.features.supports_json_field
 65            ):
 66                errors.append(
 67                    preflight.Error(
 68                        "%s does not support JSONFields." % connection.display_name,
 69                        obj=self.model,
 70                        id="fields.E180",
 71                    )
 72                )
 73        return errors
 74
 75    def deconstruct(self):
 76        name, path, args, kwargs = super().deconstruct()
 77        if self.encoder is not None:
 78            kwargs["encoder"] = self.encoder
 79        if self.decoder is not None:
 80            kwargs["decoder"] = self.decoder
 81        return name, path, args, kwargs
 82
 83    def from_db_value(self, value, expression, connection):
 84        if value is None:
 85            return value
 86        # Some backends (SQLite at least) extract non-string values in their
 87        # SQL datatypes.
 88        if isinstance(expression, KeyTransform) and not isinstance(value, str):
 89            return value
 90        try:
 91            return json.loads(value, cls=self.decoder)
 92        except json.JSONDecodeError:
 93            return value
 94
 95    def get_internal_type(self):
 96        return "JSONField"
 97
 98    def get_db_prep_value(self, value, connection, prepared=False):
 99        # RemovedInDjango51Warning: When the deprecation ends, replace with:
100        # if (
101        #     isinstance(value, expressions.Value)
102        #     and isinstance(value.output_field, JSONField)
103        # ):
104        #     value = value.value
105        # elif hasattr(value, "as_sql"): ...
106        if isinstance(value, expressions.Value):
107            if isinstance(value.value, str) and not isinstance(
108                value.output_field, JSONField
109            ):
110                try:
111                    value = json.loads(value.value, cls=self.decoder)
112                except json.JSONDecodeError:
113                    value = value.value
114                else:
115                    warnings.warn(
116                        "Providing an encoded JSON string via Value() is deprecated. "
117                        f"Use Value({value!r}, output_field=JSONField()) instead.",
118                        category=RemovedInDjango51Warning,
119                    )
120            elif isinstance(value.output_field, JSONField):
121                value = value.value
122            else:
123                return value
124        elif hasattr(value, "as_sql"):
125            return value
126        return connection.ops.adapt_json_value(value, self.encoder)
127
128    def get_db_prep_save(self, value, connection):
129        if value is None:
130            return value
131        return self.get_db_prep_value(value, connection)
132
133    def get_transform(self, name):
134        transform = super().get_transform(name)
135        if transform:
136            return transform
137        return KeyTransformFactory(name)
138
139    def validate(self, value, model_instance):
140        super().validate(value, model_instance)
141        try:
142            json.dumps(value, cls=self.encoder)
143        except TypeError:
144            raise exceptions.ValidationError(
145                self.error_messages["invalid"],
146                code="invalid",
147                params={"value": value},
148            )
149
150    def value_to_string(self, obj):
151        return self.value_from_object(obj)
152
153
def compile_json_path(key_transforms, include_root=True):
    """
    Build a JSON path string (MySQL/SQLite syntax) from a sequence of keys.

    Keys that ``int()`` accepts become array indexes (``[n]``). Any other key
    becomes a quoted member access (``."key"``). The path starts with ``$``
    unless ``include_root`` is false.
    """
    segments = ["$"] if include_root else []
    for key in key_transforms:
        try:
            index = int(key)
        except ValueError:  # non-integer
            segments.append("." + json.dumps(key))
        else:
            segments.append(f"[{index}]")
    return "".join(segments)
165
166
class DataContains(FieldGetDbPrepValueMixin, PostgresOperatorLookup):
    """``contains`` lookup: the field's JSON document contains the rhs document."""

    lookup_name = "contains"
    postgres_operator = "@>"

    def as_sql(self, compiler, connection):
        """Emit ``JSON_CONTAINS(lhs, rhs)`` where the backend supports it."""
        if not connection.features.supports_json_field_contains:
            raise NotSupportedError(
                "contains lookup is not supported on this database backend."
            )
        lhs_sql, lhs_params = self.process_lhs(compiler, connection)
        rhs_sql, rhs_params = self.process_rhs(compiler, connection)
        return f"JSON_CONTAINS({lhs_sql}, {rhs_sql})", (*lhs_params, *rhs_params)
180
181
class ContainedBy(FieldGetDbPrepValueMixin, PostgresOperatorLookup):
    """``contained_by`` lookup: the rhs JSON document contains the field's value."""

    lookup_name = "contained_by"
    postgres_operator = "<@"

    def as_sql(self, compiler, connection):
        """Emit ``JSON_CONTAINS(rhs, lhs)``; operands and params are swapped."""
        if not connection.features.supports_json_field_contains:
            raise NotSupportedError(
                "contained_by lookup is not supported on this database backend."
            )
        lhs_sql, lhs_params = self.process_lhs(compiler, connection)
        rhs_sql, rhs_params = self.process_rhs(compiler, connection)
        return f"JSON_CONTAINS({rhs_sql}, {lhs_sql})", (*rhs_params, *lhs_params)
195
196
class HasKeyLookup(PostgresOperatorLookup):
    """
    Base lookup testing whether a JSON value contains one or more keys.

    On MySQL and SQLite the condition is built by ``as_sql`` from a
    backend-specific SQL ``template`` plus one JSON path parameter per key.
    PostgreSQL delegates to ``PostgresOperatorLookup.as_postgresql``
    (presumably using the subclass's ``postgres_operator``).
    """

    # SQL joiner (e.g. " AND " / " OR ") placed between per-key conditions.
    # When None, the template is emitted once and no joining happens.
    logical_operator = None

    def compile_json_path_final_key(self, key_transform):
        # Compile the final key without interpreting ints as array elements.
        return ".%s" % json.dumps(key_transform)

    def as_sql(self, compiler, connection, template=None):
        """
        Build the key-existence SQL and params.

        ``template`` contains one ``%s`` for the compiled left-hand side and a
        ``%%s`` that becomes the JSON-path placeholder. The result is
        ``lhs_params`` followed by one JSON path per requested key.

        NOTE(review): with ``template=None`` the ``template % lhs`` line
        raises TypeError. Only the MySQL/SQLite wrappers below supply one.
        NOTE(review): when several keys are joined, the template (and so any
        placeholders in the compiled lhs) is repeated, but ``lhs_params``
        appear only once. This assumes the lhs compiles without params.
        Confirm.
        """
        # Process JSON path from the left-hand side.
        if isinstance(self.lhs, KeyTransform):
            # Nested key transforms on the lhs become a path prefix like $."a"."b".
            lhs, lhs_params, lhs_key_transforms = self.lhs.preprocess_lhs(
                compiler, connection
            )
            lhs_json_path = compile_json_path(lhs_key_transforms)
        else:
            lhs, lhs_params = self.process_lhs(compiler, connection)
            lhs_json_path = "$"
        sql = template % lhs
        # Process JSON path from the right-hand side.
        rhs = self.rhs
        rhs_params = []
        if not isinstance(rhs, list | tuple):
            rhs = [rhs]
        for key in rhs:
            if isinstance(key, KeyTransform):
                *_, rhs_key_transforms = key.preprocess_lhs(compiler, connection)
            else:
                rhs_key_transforms = [key]
            # Intermediate keys may be array indexes. The final key is compiled
            # by compile_json_path_final_key (a member name by default).
            *rhs_key_transforms, final_key = rhs_key_transforms
            rhs_json_path = compile_json_path(rhs_key_transforms, include_root=False)
            rhs_json_path += self.compile_json_path_final_key(final_key)
            rhs_params.append(lhs_json_path + rhs_json_path)
        # Add condition for each key.
        if self.logical_operator:
            sql = "(%s)" % self.logical_operator.join([sql] * len(rhs_params))
        return sql, tuple(lhs_params) + tuple(rhs_params)

    def as_mysql(self, compiler, connection):
        return self.as_sql(
            compiler, connection, template="JSON_CONTAINS_PATH(%s, 'one', %%s)"
        )

    def as_postgresql(self, compiler, connection):
        # When the rhs is a KeyTransform, fold all but its last key into
        # self.lhs as nested KeyTransforms, and keep only the final key as
        # rhs. Note: this mutates self.lhs and self.rhs.
        if isinstance(self.rhs, KeyTransform):
            *_, rhs_key_transforms = self.rhs.preprocess_lhs(compiler, connection)
            for key in rhs_key_transforms[:-1]:
                self.lhs = KeyTransform(key, self.lhs)
            self.rhs = rhs_key_transforms[-1]
        return super().as_postgresql(compiler, connection)

    def as_sqlite(self, compiler, connection):
        return self.as_sql(
            compiler, connection, template="JSON_TYPE(%s, %%s) IS NOT NULL"
        )
251
252
class HasKey(HasKeyLookup):
    """``has_key`` lookup: the JSON value contains the given key."""

    lookup_name = "has_key"
    postgres_operator = "?"
    # The rhs is a key name rather than a JSON document. prepare_rhs=False
    # presumably keeps it from being prepared as a field value.
    prepare_rhs = False


class HasKeys(HasKeyLookup):
    """``has_keys`` lookup: every one of the given keys is present."""

    lookup_name = "has_keys"
    postgres_operator = "?&"
    logical_operator = " AND "

    def get_prep_lookup(self):
        # Keys are matched as strings, so coerce each requested key.
        return list(map(str, self.rhs))


class HasAnyKeys(HasKeys):
    """``has_any_keys`` lookup: at least one of the given keys is present."""

    lookup_name = "has_any_keys"
    postgres_operator = "?|"
    logical_operator = " OR "


class HasKeyOrArrayIndex(HasKey):
    """HasKey variant whose final key may also be an integer array index."""

    def compile_json_path_final_key(self, key_transform):
        # Unlike HasKeyLookup, integer-like keys become ``[n]`` array indexes.
        return compile_json_path([key_transform], include_root=False)
277
278
class CaseInsensitiveMixin:
    """
    Make JSON value comparisons case-insensitive on MySQL.

    MySQL treats strings in a JSON context with the binary utf8mb4_bin
    collation, which compares case-sensitively. On that vendor both sides of
    the lookup are wrapped in ``LOWER()``. Other vendors are left unchanged.
    """

    def process_lhs(self, compiler, connection):
        sql, params = super().process_lhs(compiler, connection)
        if connection.vendor != "mysql":
            return sql, params
        return "LOWER(%s)" % sql, params

    def process_rhs(self, compiler, connection):
        sql, params = super().process_rhs(compiler, connection)
        if connection.vendor != "mysql":
            return sql, params
        return "LOWER(%s)" % sql, params
298
299
class JSONExact(lookups.Exact):
    """Exact match on a JSON value, where a ``None`` rhs means JSON ``null``."""

    can_use_none_as_rhs = True

    def process_rhs(self, compiler, connection):
        sql, params = super().process_rhs(compiler, connection)
        # A lone None parameter is sent as the JSON text "null".
        if sql == "%s" and params == [None]:
            params = ["null"]
        if connection.vendor == "mysql":
            # Wrap every placeholder so MySQL parses it as a JSON value.
            extract_calls = ("JSON_EXTRACT(%s, '$')",) * len(params)
            sql %= extract_calls
        return sql, params
312
313
class JSONIContains(CaseInsensitiveMixin, lookups.IContains):
    """``icontains`` on a JSONField, with both sides lowercased on MySQL."""
316
317
# Lookups usable directly on a JSONField (e.g. ``field__contains``,
# ``field__has_key``).
JSONField.register_lookup(DataContains)
JSONField.register_lookup(ContainedBy)
JSONField.register_lookup(HasKey)
JSONField.register_lookup(HasKeys)
JSONField.register_lookup(HasAnyKeys)
JSONField.register_lookup(JSONExact)
JSONField.register_lookup(JSONIContains)
325
326
class KeyTransform(Transform):
    """Extract the value stored under ``key_name`` from a JSON expression."""

    postgres_operator = "->"
    postgres_nested_operator = "#>"

    def __init__(self, key_name, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Always kept as a string; integer-like keys are turned into array
        # indexes later, when a JSON path is compiled.
        self.key_name = str(key_name)

    def preprocess_lhs(self, compiler, connection):
        """
        Unwind a chain of nested KeyTransforms.

        Return ``(sql, params, keys)``. ``sql``/``params`` come from compiling
        the innermost non-KeyTransform expression. ``keys`` lists the key
        names from the outermost nesting level down to this transform's own
        key.
        """
        keys = []
        node = self
        while isinstance(node, KeyTransform):
            keys.append(node.key_name)
            node = node.lhs
        keys.reverse()
        lhs, params = compiler.compile(node)
        return lhs, params, keys

    def as_mysql(self, compiler, connection):
        lhs, params, keys = self.preprocess_lhs(compiler, connection)
        path = compile_json_path(keys)
        return "JSON_EXTRACT(%s, %%s)" % lhs, (*params, path)

    def as_postgresql(self, compiler, connection):
        lhs, params, keys = self.preprocess_lhs(compiler, connection)
        if len(keys) > 1:
            # Nested access: the whole key list is passed as one parameter.
            return f"({lhs} {self.postgres_nested_operator} %s)", (*params, keys)
        key = self.key_name
        try:
            key = int(key)
        except ValueError:
            pass
        return f"({lhs} {self.postgres_operator} %s)", (*params, key)

    def as_sqlite(self, compiler, connection):
        lhs, params, keys = self.preprocess_lhs(compiler, connection)
        path = compile_json_path(keys)
        datatypes_sql = ",".join(
            repr(datatype) for datatype in connection.ops.jsonfield_datatype_values
        )
        sql = (
            f"(CASE WHEN JSON_TYPE({lhs}, %s) IN ({datatypes_sql}) "
            f"THEN JSON_TYPE({lhs}, %s) ELSE JSON_EXTRACT({lhs}, %s) END)"
        )
        # The lhs params and the path are needed once per JSON_* call above.
        return sql, (*params, path) * 3
370
371
class KeyTextTransform(KeyTransform):
    """
    KeyTransform variant that yields the extracted value as text.

    Uses ``->>``/``#>>`` on PostgreSQL and ``->>`` on MySQL. MariaDB falls
    back to ``JSON_UNQUOTE`` around the parent's ``JSON_EXTRACT`` SQL.
    """

    postgres_operator = "->>"
    postgres_nested_operator = "#>>"
    output_field = TextField()

    def as_mysql(self, compiler, connection):
        if not connection.mysql_is_mariadb:
            lhs, params, key_transforms = self.preprocess_lhs(compiler, connection)
            path = compile_json_path(key_transforms)
            return "(%s ->> %%s)" % lhs, (*params, path)
        # MariaDB doesn't support -> and ->> operators (see MDEV-13594).
        extract_sql, extract_params = super().as_mysql(compiler, connection)
        return "JSON_UNQUOTE(%s)" % extract_sql, extract_params

    @classmethod
    def from_lookup(cls, lookup):
        """
        Build nested KeyTextTransforms from a lookup string.

        ``lookup`` is split on ``LOOKUP_SEP``. The first segment is the base
        expression and every following segment wraps it in one more key
        transform.

        Raises:
            ValueError: if ``lookup`` has no key segment after the base.
        """
        base, *keys = lookup.split(LOOKUP_SEP)
        if not keys:
            raise ValueError("Lookup must contain key or index transforms.")
        expression = base
        for key in keys:
            expression = cls(key, expression)
        return expression
395
396
# Shorthand: KT(lookup) splits ``lookup`` on LOOKUP_SEP and builds nested
# KeyTextTransforms (see KeyTextTransform.from_lookup).
KT = KeyTextTransform.from_lookup
398
399
class KeyTransformTextLookupMixin:
    """
    Lookup mixin that runs a text lookup on a JSON key's value.

    The incoming KeyTransform is replaced by an equivalent KeyTextTransform
    (same key, source expressions and extra). On PostgreSQL that uses the
    ``->>`` operator, rather than casting the key's value to text and
    matching against that representation.
    """

    def __init__(self, key_transform, *args, **kwargs):
        if isinstance(key_transform, KeyTransform):
            as_text = KeyTextTransform(
                key_transform.key_name,
                *key_transform.source_expressions,
                **key_transform.extra,
            )
            super().__init__(as_text, *args, **kwargs)
            return
        raise TypeError(
            "Transform should be an instance of KeyTransform in order to "
            "use this lookup."
        )
420
421
class KeyTransformIsNull(lookups.IsNull):
    # key__isnull=False is the same as has_key='key'
    def as_sqlite(self, compiler, connection):
        """Express key__isnull through a JSON_TYPE() existence check."""
        template = (
            "JSON_TYPE(%s, %%s) IS NULL"
            if self.rhs
            else "JSON_TYPE(%s, %%s) IS NOT NULL"
        )
        existence = HasKeyOrArrayIndex(self.lhs.lhs, self.lhs.key_name)
        return existence.as_sql(compiler, connection, template=template)
433
434
class KeyTransformIn(lookups.In):
    """``in`` lookup on a JSON key, adapting each list element per backend."""

    def resolve_expression_parameter(self, compiler, connection, sql, param):
        """
        Adjust the SQL placeholder for one IN-list element.

        On backends without a native JSON field, plain (non-expression)
        values are wrapped in ``JSON_EXTRACT(%s, '$')``. This applies always
        on MySQL, and on SQLite unless the param is one of
        ``jsonfield_datatype_values``. On MariaDB the result is additionally
        wrapped in ``JSON_UNQUOTE``.
        """
        sql, params = super().resolve_expression_parameter(
            compiler, connection, sql, param
        )
        is_plain_value = not hasattr(param, "as_sql")
        if is_plain_value and not connection.features.has_native_json_field:
            vendor = connection.vendor
            needs_extract = vendor == "mysql" or (
                vendor == "sqlite"
                and params[0] not in connection.ops.jsonfield_datatype_values
            )
            if needs_extract:
                sql = "JSON_EXTRACT(%s, '$')"
        if connection.vendor == "mysql" and connection.mysql_is_mariadb:
            sql = "JSON_UNQUOTE(%s)" % sql
        return sql, params
455
456
class KeyTransformExact(JSONExact):
    """Exact match on a JSON key's value."""

    def process_rhs(self, compiler, connection):
        if isinstance(self.rhs, KeyTransform):
            # Key-to-key comparison: resolve the MRO starting after
            # lookups.Exact, which skips JSONExact's rhs handling.
            return super(lookups.Exact, self).process_rhs(compiler, connection)
        rhs, rhs_params = super().process_rhs(compiler, connection)
        if connection.vendor == "sqlite":
            datatype_values = connection.ops.jsonfield_datatype_values
            placeholders = tuple(
                "%s" if value in datatype_values else "JSON_EXTRACT(%s, '$')"
                for value in rhs_params
            )
            rhs %= placeholders
        return rhs, rhs_params
471
472
class KeyTransformIExact(
    CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IExact
):
    """Case-insensitive exact match on a JSON key's text value."""


class KeyTransformIContains(
    CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IContains
):
    """Case-insensitive substring match on a JSON key's text value."""


class KeyTransformStartsWith(KeyTransformTextLookupMixin, lookups.StartsWith):
    """Prefix match on a JSON key's text value."""


class KeyTransformIStartsWith(
    CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IStartsWith
):
    """Case-insensitive prefix match on a JSON key's text value."""


class KeyTransformEndsWith(KeyTransformTextLookupMixin, lookups.EndsWith):
    """Suffix match on a JSON key's text value."""


class KeyTransformIEndsWith(
    CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IEndsWith
):
    """Case-insensitive suffix match on a JSON key's text value."""


class KeyTransformRegex(KeyTransformTextLookupMixin, lookups.Regex):
    """Regular-expression match on a JSON key's text value."""


class KeyTransformIRegex(
    CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IRegex
):
    """Case-insensitive regular-expression match on a JSON key's text value."""
513
514
class KeyTransformNumericLookupMixin:
    """
    Lookup mixin for ordered comparisons on JSON key values.

    On backends without a native JSON field, each rhs parameter arrives as
    JSON text and is decoded with ``json.loads`` before being bound.
    """

    def process_rhs(self, compiler, connection):
        rhs, rhs_params = super().process_rhs(compiler, connection)
        if connection.features.has_native_json_field:
            return rhs, rhs_params
        return rhs, list(map(json.loads, rhs_params))
521
522
class KeyTransformLt(KeyTransformNumericLookupMixin, lookups.LessThan):
    """``lt`` comparison on a JSON key's value."""


class KeyTransformLte(KeyTransformNumericLookupMixin, lookups.LessThanOrEqual):
    """``lte`` comparison on a JSON key's value."""


class KeyTransformGt(KeyTransformNumericLookupMixin, lookups.GreaterThan):
    """``gt`` comparison on a JSON key's value."""


class KeyTransformGte(KeyTransformNumericLookupMixin, lookups.GreaterThanOrEqual):
    """``gte`` comparison on a JSON key's value."""
537
538
# Lookups usable on an individual JSON key (e.g. ``field__key__icontains``).
KeyTransform.register_lookup(KeyTransformIn)
KeyTransform.register_lookup(KeyTransformExact)
KeyTransform.register_lookup(KeyTransformIExact)
KeyTransform.register_lookup(KeyTransformIsNull)
KeyTransform.register_lookup(KeyTransformIContains)
KeyTransform.register_lookup(KeyTransformStartsWith)
KeyTransform.register_lookup(KeyTransformIStartsWith)
KeyTransform.register_lookup(KeyTransformEndsWith)
KeyTransform.register_lookup(KeyTransformIEndsWith)
KeyTransform.register_lookup(KeyTransformRegex)
KeyTransform.register_lookup(KeyTransformIRegex)

# Ordered comparisons; rhs values are decoded by KeyTransformNumericLookupMixin
# on backends without a native JSON field.
KeyTransform.register_lookup(KeyTransformLt)
KeyTransform.register_lookup(KeyTransformLte)
KeyTransform.register_lookup(KeyTransformGt)
KeyTransform.register_lookup(KeyTransformGte)
555
556
class KeyTransformFactory:
    """
    Callable that builds a KeyTransform for one fixed key name.

    Returned by ``JSONField.get_transform`` for names that are not registered
    transforms. Calling it forwards all arguments to ``KeyTransform`` after
    the stored key.
    """

    def __init__(self, key_name):
        self.key_name = key_name

    def __call__(self, *args, **kwargs):
        key = self.key_name
        return KeyTransform(key, *args, **kwargs)