1import json
2import warnings
3
4from plain import exceptions, preflight
5from plain.models import expressions, lookups
6from plain.models.constants import LOOKUP_SEP
7from plain.models.db import NotSupportedError, connections, router
8from plain.models.fields import TextField
9from plain.models.lookups import (
10 FieldGetDbPrepValueMixin,
11 PostgresOperatorLookup,
12 Transform,
13)
14from plain.utils.deprecation import RemovedInDjango51Warning
15
16from . import Field
17from .mixins import CheckFieldDefaultMixin
18
19__all__ = ["JSONField"]
20
21
class JSONField(CheckFieldDefaultMixin, Field):
    """A model field that stores JSON-encodable Python values.

    Values are encoded with ``json.dumps`` (optionally through a custom
    ``encoder`` class) on the way to the database and decoded with
    ``json.loads`` (optionally through a custom ``decoder`` class) when read
    back.
    """

    empty_strings_allowed = False
    description = "A JSON object"
    default_error_messages = {
        "invalid": "Value must be valid JSON.",
    }
    # Hint used by CheckFieldDefaultMixin when suggesting a safe default.
    _default_hint = ("dict", "{}")

    def __init__(
        self,
        name=None,
        encoder=None,
        decoder=None,
        **kwargs,
    ):
        # encoder/decoder must be callables usable as json.dumps(cls=...) /
        # json.loads(cls=...).
        if encoder and not callable(encoder):
            raise ValueError("The encoder parameter must be a callable object.")
        if decoder and not callable(decoder):
            raise ValueError("The decoder parameter must be a callable object.")
        self.encoder = encoder
        self.decoder = decoder
        super().__init__(name, **kwargs)

    def check(self, **kwargs):
        """Run base field checks plus a per-database JSON support check."""
        errors = super().check(**kwargs)
        databases = kwargs.get("databases") or []
        errors.extend(self._check_supported(databases))
        return errors

    def _check_supported(self, databases):
        # Emit fields.E180 for each configured database alias that cannot
        # store JSON values, unless the model explicitly opts in via
        # required_db_features.
        errors = []
        for db in databases:
            if not router.allow_migrate_model(db, self.model):
                continue
            connection = connections[db]
            if (
                self.model._meta.required_db_vendor
                and self.model._meta.required_db_vendor != connection.vendor
            ):
                # The model is restricted to a different vendor; skip.
                continue
            if not (
                "supports_json_field" in self.model._meta.required_db_features
                or connection.features.supports_json_field
            ):
                errors.append(
                    preflight.Error(
                        "%s does not support JSONFields." % connection.display_name,
                        obj=self.model,
                        id="fields.E180",
                    )
                )
        return errors

    def deconstruct(self):
        """Include custom encoder/decoder in the migration-serializable form."""
        name, path, args, kwargs = super().deconstruct()
        if self.encoder is not None:
            kwargs["encoder"] = self.encoder
        if self.decoder is not None:
            kwargs["decoder"] = self.decoder
        return name, path, args, kwargs

    def from_db_value(self, value, expression, connection):
        """Decode a raw database value into a Python object."""
        if value is None:
            return value
        # Some backends (SQLite at least) extract non-string values in their
        # SQL datatypes.
        if isinstance(expression, KeyTransform) and not isinstance(value, str):
            return value
        try:
            return json.loads(value, cls=self.decoder)
        except json.JSONDecodeError:
            # Not valid JSON (e.g. a bare string from a key transform);
            # return it unchanged.
            return value

    def get_internal_type(self):
        return "JSONField"

    def get_db_prep_value(self, value, connection, prepared=False):
        # RemovedInDjango51Warning: When the deprecation ends, replace with:
        # if (
        #     isinstance(value, expressions.Value)
        #     and isinstance(value.output_field, JSONField)
        # ):
        #     value = value.value
        # elif hasattr(value, "as_sql"): ...
        if isinstance(value, expressions.Value):
            if isinstance(value.value, str) and not isinstance(
                value.output_field, JSONField
            ):
                # Deprecated path: a pre-encoded JSON string wrapped in
                # Value() without a JSONField output_field.
                try:
                    value = json.loads(value.value, cls=self.decoder)
                except json.JSONDecodeError:
                    value = value.value
                else:
                    warnings.warn(
                        "Providing an encoded JSON string via Value() is deprecated. "
                        f"Use Value({value!r}, output_field=JSONField()) instead.",
                        category=RemovedInDjango51Warning,
                    )
            elif isinstance(value.output_field, JSONField):
                value = value.value
            else:
                return value
        elif hasattr(value, "as_sql"):
            # Expressions are compiled to SQL elsewhere; pass through as-is.
            return value
        return connection.ops.adapt_json_value(value, self.encoder)

    def get_db_prep_save(self, value, connection):
        if value is None:
            return value
        return self.get_db_prep_value(value, connection)

    def get_transform(self, name):
        """Fall back to key/index transforms for unknown transform names."""
        transform = super().get_transform(name)
        if transform:
            return transform
        return KeyTransformFactory(name)

    def validate(self, value, model_instance):
        # A value is valid only if json.dumps can serialize it with the
        # configured encoder.
        super().validate(value, model_instance)
        try:
            json.dumps(value, cls=self.encoder)
        except TypeError:
            raise exceptions.ValidationError(
                self.error_messages["invalid"],
                code="invalid",
                params={"value": value},
            )

    def value_to_string(self, obj):
        # Return the raw Python value; serializers handle JSON encoding.
        return self.value_from_object(obj)
152
153
def compile_json_path(key_transforms, include_root=True):
    """Build a JSON path string (e.g. ``$."a"[1]``) from a list of keys.

    Integer-like keys become array indexes (``[n]``); everything else becomes
    a quoted object member (``."key"``). When *include_root* is true, the path
    is anchored at ``$``.
    """
    segments = ["$"] if include_root else []
    for key in key_transforms:
        try:
            index = int(key)
        except ValueError:
            # Non-numeric key: quote it via json.dumps so special characters
            # stay valid inside the path expression.
            segments.append(f".{json.dumps(key)}")
        else:
            segments.append(f"[{index}]")
    return "".join(segments)
165
166
class DataContains(FieldGetDbPrepValueMixin, PostgresOperatorLookup):
    """Containment lookup: the stored JSON value contains the given JSON."""

    lookup_name = "contains"
    postgres_operator = "@>"

    def as_sql(self, compiler, connection):
        if not connection.features.supports_json_field_contains:
            raise NotSupportedError(
                "contains lookup is not supported on this database backend."
            )
        lhs_sql, lhs_params = self.process_lhs(compiler, connection)
        rhs_sql, rhs_params = self.process_rhs(compiler, connection)
        sql = f"JSON_CONTAINS({lhs_sql}, {rhs_sql})"
        return sql, (*lhs_params, *rhs_params)
180
181
class ContainedBy(FieldGetDbPrepValueMixin, PostgresOperatorLookup):
    """Reverse containment: the given JSON contains the stored value."""

    lookup_name = "contained_by"
    postgres_operator = "<@"

    def as_sql(self, compiler, connection):
        if not connection.features.supports_json_field_contains:
            raise NotSupportedError(
                "contained_by lookup is not supported on this database backend."
            )
        lhs_sql, lhs_params = self.process_lhs(compiler, connection)
        rhs_sql, rhs_params = self.process_rhs(compiler, connection)
        # Arguments are swapped relative to DataContains: rhs is the container.
        sql = f"JSON_CONTAINS({rhs_sql}, {lhs_sql})"
        return sql, (*rhs_params, *lhs_params)
195
196
class HasKeyLookup(PostgresOperatorLookup):
    """Base lookup for has_key / has_keys / has_any_keys.

    Subclasses set ``logical_operator`` to combine multiple per-key
    conditions (" AND " for has_keys, " OR " for has_any_keys).
    """

    logical_operator = None

    def compile_json_path_final_key(self, key_transform):
        # Compile the final key without interpreting ints as array elements.
        return ".%s" % json.dumps(key_transform)

    def as_sql(self, compiler, connection, template=None):
        # Process JSON path from the left-hand side.
        if isinstance(self.lhs, KeyTransform):
            lhs, lhs_params, lhs_key_transforms = self.lhs.preprocess_lhs(
                compiler, connection
            )
            lhs_json_path = compile_json_path(lhs_key_transforms)
        else:
            lhs, lhs_params = self.process_lhs(compiler, connection)
            lhs_json_path = "$"
        sql = template % lhs
        # Process JSON path from the right-hand side.
        rhs = self.rhs
        rhs_params = []
        if not isinstance(rhs, list | tuple):
            rhs = [rhs]
        for key in rhs:
            if isinstance(key, KeyTransform):
                *_, rhs_key_transforms = key.preprocess_lhs(compiler, connection)
            else:
                rhs_key_transforms = [key]
            # The last key keeps string semantics even when it looks numeric;
            # see compile_json_path_final_key().
            *rhs_key_transforms, final_key = rhs_key_transforms
            rhs_json_path = compile_json_path(rhs_key_transforms, include_root=False)
            rhs_json_path += self.compile_json_path_final_key(final_key)
            rhs_params.append(lhs_json_path + rhs_json_path)
        # Add condition for each key.
        if self.logical_operator:
            sql = "(%s)" % self.logical_operator.join([sql] * len(rhs_params))
        return sql, tuple(lhs_params) + tuple(rhs_params)

    def as_mysql(self, compiler, connection):
        return self.as_sql(
            compiler, connection, template="JSON_CONTAINS_PATH(%s, 'one', %%s)"
        )

    def as_postgresql(self, compiler, connection):
        # PostgreSQL's key-existence operators only test the final key, so
        # fold any preceding keys into the lhs as nested KeyTransforms.
        if isinstance(self.rhs, KeyTransform):
            *_, rhs_key_transforms = self.rhs.preprocess_lhs(compiler, connection)
            for key in rhs_key_transforms[:-1]:
                self.lhs = KeyTransform(key, self.lhs)
            self.rhs = rhs_key_transforms[-1]
        return super().as_postgresql(compiler, connection)

    def as_sqlite(self, compiler, connection):
        return self.as_sql(
            compiler, connection, template="JSON_TYPE(%s, %%s) IS NOT NULL"
        )
251
252
class HasKey(HasKeyLookup):
    """Match objects whose JSON value has the given top-level key."""

    lookup_name = "has_key"
    postgres_operator = "?"
    prepare_rhs = False
257
258
class HasKeys(HasKeyLookup):
    """Match objects whose JSON value has all of the given keys."""

    lookup_name = "has_keys"
    postgres_operator = "?&"
    logical_operator = " AND "

    def get_prep_lookup(self):
        # Keys are always compared as strings.
        return list(map(str, self.rhs))
266
267
class HasAnyKeys(HasKeys):
    """Match objects whose JSON value has at least one of the given keys."""

    lookup_name = "has_any_keys"
    postgres_operator = "?|"
    logical_operator = " OR "
272
273
class HasKeyOrArrayIndex(HasKey):
    """HasKey variant that treats integer-like final keys as array indexes."""

    def compile_json_path_final_key(self, key_transform):
        # Delegate to compile_json_path() so numeric keys compile to [n].
        return compile_json_path([key_transform], include_root=False)
277
278
class CaseInsensitiveMixin:
    """
    Mixin to allow case-insensitive comparison of JSON values on MySQL.
    MySQL handles strings used in JSON context using the utf8mb4_bin collation.
    Because utf8mb4_bin is a binary collation, comparison of JSON values is
    case-sensitive.
    """

    def process_lhs(self, compiler, connection):
        lhs, lhs_params = super().process_lhs(compiler, connection)
        if connection.vendor != "mysql":
            return lhs, lhs_params
        return f"LOWER({lhs})", lhs_params

    def process_rhs(self, compiler, connection):
        rhs, rhs_params = super().process_rhs(compiler, connection)
        if connection.vendor != "mysql":
            return rhs, rhs_params
        return f"LOWER({rhs})", rhs_params
298
299
class JSONExact(lookups.Exact):
    """Exact lookup comparing values as JSON rather than raw strings."""

    can_use_none_as_rhs = True

    def process_rhs(self, compiler, connection):
        rhs, rhs_params = super().process_rhs(compiler, connection)
        # A None lookup value means JSON null, not SQL NULL.
        if rhs == "%s" and rhs_params == [None]:
            rhs_params = ["null"]
        if connection.vendor == "mysql":
            # Route each parameter through JSON_EXTRACT so both sides compare
            # as JSON values.
            rhs %= tuple("JSON_EXTRACT(%s, '$')" for _ in rhs_params)
        return rhs, rhs_params
312
313
class JSONIContains(CaseInsensitiveMixin, lookups.IContains):
    """Case-insensitive containment test on the JSON text representation."""
316
317
# Register the JSON-specific lookups on JSONField.
for _lookup in (
    DataContains,
    ContainedBy,
    HasKey,
    HasKeys,
    HasAnyKeys,
    JSONExact,
    JSONIContains,
):
    JSONField.register_lookup(_lookup)
325
326
class KeyTransform(Transform):
    """Transform that extracts a key or array index from a JSON value."""

    postgres_operator = "->"
    postgres_nested_operator = "#>"

    def __init__(self, key_name, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Keys are stored as strings; numeric strings are resolved to array
        # indexes at SQL-compilation time.
        self.key_name = str(key_name)

    def preprocess_lhs(self, compiler, connection):
        # Collapse a chain of nested KeyTransforms into the innermost lhs
        # plus the ordered list of key names applied to it.
        key_transforms = [self.key_name]
        previous = self.lhs
        while isinstance(previous, KeyTransform):
            key_transforms.insert(0, previous.key_name)
            previous = previous.lhs
        lhs, params = compiler.compile(previous)
        return lhs, params, key_transforms

    def as_mysql(self, compiler, connection):
        lhs, params, key_transforms = self.preprocess_lhs(compiler, connection)
        json_path = compile_json_path(key_transforms)
        return "JSON_EXTRACT(%s, %%s)" % lhs, tuple(params) + (json_path,)

    def as_postgresql(self, compiler, connection):
        lhs, params, key_transforms = self.preprocess_lhs(compiler, connection)
        if len(key_transforms) > 1:
            # Multiple keys: use the nested-path operator with the whole key
            # list as a single parameter.
            sql = f"({lhs} {self.postgres_nested_operator} %s)"
            return sql, tuple(params) + (key_transforms,)
        try:
            lookup = int(self.key_name)
        except ValueError:
            lookup = self.key_name
        return f"({lhs} {self.postgres_operator} %s)", tuple(params) + (lookup,)

    def as_sqlite(self, compiler, connection):
        lhs, params, key_transforms = self.preprocess_lhs(compiler, connection)
        json_path = compile_json_path(key_transforms)
        datatype_values = ",".join(
            [repr(datatype) for datatype in connection.ops.jsonfield_datatype_values]
        )
        # For datatypes listed by the backend ops, return JSON_TYPE()'s result
        # instead of JSON_EXTRACT()'s decoded value — presumably the literal
        # primitives whose type name equals their value; verify against the
        # backend's jsonfield_datatype_values. The same (params + path) tuple
        # is bound three times, once per %s placeholder.
        return (
            f"(CASE WHEN JSON_TYPE({lhs}, %s) IN ({datatype_values}) "
            f"THEN JSON_TYPE({lhs}, %s) ELSE JSON_EXTRACT({lhs}, %s) END)"
        ), (tuple(params) + (json_path,)) * 3
370
371
class KeyTextTransform(KeyTransform):
    """KeyTransform variant that yields text instead of a JSON value."""

    postgres_operator = "->>"
    postgres_nested_operator = "#>>"
    output_field = TextField()

    def as_mysql(self, compiler, connection):
        if connection.mysql_is_mariadb:
            # MariaDB doesn't support -> and ->> operators (see MDEV-13594).
            sql, params = super().as_mysql(compiler, connection)
            return "JSON_UNQUOTE(%s)" % sql, params
        else:
            lhs, params, key_transforms = self.preprocess_lhs(compiler, connection)
            json_path = compile_json_path(key_transforms)
            return "(%s ->> %%s)" % lhs, tuple(params) + (json_path,)

    @classmethod
    def from_lookup(cls, lookup):
        """Build a nested KeyTextTransform chain from a LOOKUP_SEP string.

        E.g. ``"field__a__b"`` becomes ``cls("b", cls("a", "field"))``.
        Raises ValueError when no key follows the field name.
        """
        transform, *keys = lookup.split(LOOKUP_SEP)
        if not keys:
            raise ValueError("Lookup must contain key or index transforms.")
        for key in keys:
            transform = cls(key, transform)
        return transform
395
396
397KT = KeyTextTransform.from_lookup
398
399
class KeyTransformTextLookupMixin:
    """
    Mixin for combining with a lookup expecting a text lhs from a JSONField
    key lookup. On PostgreSQL, make use of the ->> operator instead of casting
    key values to text and performing the lookup on the resulting
    representation.
    """

    def __init__(self, key_transform, *args, **kwargs):
        if not isinstance(key_transform, KeyTransform):
            raise TypeError(
                "Transform should be an instance of KeyTransform in order to "
                "use this lookup."
            )
        # Replace the JSON-valued transform with its text-returning twin.
        as_text = KeyTextTransform(
            key_transform.key_name,
            *key_transform.source_expressions,
            **key_transform.extra,
        )
        super().__init__(as_text, *args, **kwargs)
420
421
class KeyTransformIsNull(lookups.IsNull):
    # key__isnull=False is the same as has_key='key'
    def as_sqlite(self, compiler, connection):
        # rhs is the isnull flag: True asks for "key absent", False for
        # "key present" — exactly what HasKeyOrArrayIndex can express.
        if self.rhs:
            template = "JSON_TYPE(%s, %%s) IS NULL"
        else:
            template = "JSON_TYPE(%s, %%s) IS NOT NULL"
        return HasKeyOrArrayIndex(self.lhs.lhs, self.lhs.key_name).as_sql(
            compiler, connection, template=template
        )
433
434
class KeyTransformIn(lookups.In):
    def resolve_expression_parameter(self, compiler, connection, sql, param):
        """Adapt each rhs parameter so it compares as a JSON value."""
        sql, params = super().resolve_expression_parameter(
            compiler,
            connection,
            sql,
            param,
        )
        # Only plain values (not expressions) on non-native-JSON backends
        # need wrapping.
        if (
            not hasattr(param, "as_sql")
            and not connection.features.has_native_json_field
        ):
            if connection.vendor == "mysql" or (
                connection.vendor == "sqlite"
                and params[0] not in connection.ops.jsonfield_datatype_values
            ):
                sql = "JSON_EXTRACT(%s, '$')"
            if connection.vendor == "mysql" and connection.mysql_is_mariadb:
                # MariaDB's JSON_EXTRACT returns quoted strings; unquote them.
                sql = "JSON_UNQUOTE(%s)" % sql
        return sql, params
455
456
class KeyTransformExact(JSONExact):
    """Exact lookup on a JSON key transform."""

    def process_rhs(self, compiler, connection):
        if isinstance(self.rhs, KeyTransform):
            # Comparing against another key transform: bypass JSONExact's
            # value coercion by jumping above lookups.Exact in the MRO.
            return super(lookups.Exact, self).process_rhs(compiler, connection)
        rhs, rhs_params = super().process_rhs(compiler, connection)
        if connection.vendor == "sqlite":
            # Backend-listed datatype literals stay as plain placeholders;
            # everything else is normalized through JSON_EXTRACT.
            placeholders = tuple(
                "%s"
                if param in connection.ops.jsonfield_datatype_values
                else "JSON_EXTRACT(%s, '$')"
                for param in rhs_params
            )
            rhs %= placeholders
        return rhs, rhs_params
471
472
class KeyTransformIExact(
    CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IExact
):
    """Case-insensitive exact match on the text value of a JSON key."""
477
478
class KeyTransformIContains(
    CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IContains
):
    """Case-insensitive substring match on the text value of a JSON key."""
483
484
class KeyTransformStartsWith(KeyTransformTextLookupMixin, lookups.StartsWith):
    """Prefix match on the text value of a JSON key."""
487
488
class KeyTransformIStartsWith(
    CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IStartsWith
):
    """Case-insensitive prefix match on the text value of a JSON key."""
493
494
class KeyTransformEndsWith(KeyTransformTextLookupMixin, lookups.EndsWith):
    """Suffix match on the text value of a JSON key."""
497
498
class KeyTransformIEndsWith(
    CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IEndsWith
):
    """Case-insensitive suffix match on the text value of a JSON key."""
503
504
class KeyTransformRegex(KeyTransformTextLookupMixin, lookups.Regex):
    """Regular-expression match on the text value of a JSON key."""
507
508
class KeyTransformIRegex(
    CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IRegex
):
    """Case-insensitive regular-expression match on a JSON key's text."""
513
514
class KeyTransformNumericLookupMixin:
    """Decode JSON-encoded rhs params on backends without native JSON."""

    def process_rhs(self, compiler, connection):
        rhs, rhs_params = super().process_rhs(compiler, connection)
        if not connection.features.has_native_json_field:
            # Parameters arrive JSON-encoded; decode them back to numbers so
            # the comparison is numeric rather than lexicographic.
            rhs_params = list(map(json.loads, rhs_params))
        return rhs, rhs_params
521
522
class KeyTransformLt(KeyTransformNumericLookupMixin, lookups.LessThan):
    """Numeric less-than comparison on a JSON key value."""
525
526
class KeyTransformLte(KeyTransformNumericLookupMixin, lookups.LessThanOrEqual):
    """Numeric less-than-or-equal comparison on a JSON key value."""
529
530
class KeyTransformGt(KeyTransformNumericLookupMixin, lookups.GreaterThan):
    """Numeric greater-than comparison on a JSON key value."""
533
534
class KeyTransformGte(KeyTransformNumericLookupMixin, lookups.GreaterThanOrEqual):
    """Numeric greater-than-or-equal comparison on a JSON key value."""
537
538
# Register all key-transform lookups on KeyTransform.
for _key_lookup in (
    KeyTransformIn,
    KeyTransformExact,
    KeyTransformIExact,
    KeyTransformIsNull,
    KeyTransformIContains,
    KeyTransformStartsWith,
    KeyTransformIStartsWith,
    KeyTransformEndsWith,
    KeyTransformIEndsWith,
    KeyTransformRegex,
    KeyTransformIRegex,
    KeyTransformLt,
    KeyTransformLte,
    KeyTransformGt,
    KeyTransformGte,
):
    KeyTransform.register_lookup(_key_lookup)
555
556
class KeyTransformFactory:
    """Callable that builds KeyTransform instances for one fixed key name.

    Returned by JSONField.get_transform() so every unrecognized transform
    name resolves to a key/index lookup on the JSON value.
    """

    def __init__(self, key_name):
        self.key_name = key_name

    def __call__(self, *args, **kwargs):
        # Forward everything to KeyTransform, prepending the captured key.
        return KeyTransform(self.key_name, *args, **kwargs)