1from __future__ import annotations
2
3import json
4from typing import TYPE_CHECKING, Any, cast
5
6from plain import exceptions, preflight
7from plain.models import expressions, lookups
8from plain.models.constants import LOOKUP_SEP
9from plain.models.db import NotSupportedError, db_connection
10from plain.models.fields import TextField
11from plain.models.lookups import (
12 FieldGetDbPrepValueMixin,
13 PostgresOperatorLookup,
14 Transform,
15)
16
17from . import Field
18from .mixins import CheckFieldDefaultMixin
19
20if TYPE_CHECKING:
21 from plain.models.backends.base.base import BaseDatabaseWrapper
22 from plain.models.backends.mysql.base import MySQLDatabaseWrapper
23 from plain.models.backends.sqlite3.base import SQLiteDatabaseWrapper
24 from plain.models.sql.compiler import SQLCompiler
25 from plain.preflight.results import PreflightResult
26
27__all__ = ["JSONField"]
28
29
class JSONField(CheckFieldDefaultMixin, Field):
    """A model field that stores JSON-serializable data.

    Values are adapted for the database via ``connection.ops.adapt_json_value``
    (optionally with a custom ``encoder``) and decoded from database strings
    with ``json.loads`` (optionally with a custom ``decoder``).
    """

    empty_strings_allowed = False
    description = "A JSON object"
    default_error_messages = {
        "invalid": "Value must be valid JSON.",
    }
    _default_fix = ("dict", "{}")

    def __init__(
        self,
        *,
        encoder: type[json.JSONEncoder] | None = None,
        decoder: type[json.JSONDecoder] | None = None,
        **kwargs: Any,
    ):
        if encoder and not callable(encoder):
            raise ValueError("The encoder parameter must be a callable object.")
        if decoder and not callable(decoder):
            raise ValueError("The decoder parameter must be a callable object.")
        self.encoder = encoder
        self.decoder = decoder
        super().__init__(**kwargs)

    def preflight(self, **kwargs: Any) -> list[PreflightResult]:
        """Run the base field checks plus the backend JSON-support check."""
        return [*super().preflight(**kwargs), *self._check_supported()]

    def _check_supported(self) -> list[PreflightResult]:
        # Skip the check entirely when the model is pinned to a db vendor
        # other than the active connection's.
        required_vendor = self.model.model_options.required_db_vendor
        if required_vendor and required_vendor != db_connection.vendor:
            return []
        supported = (
            "supports_json_field" in self.model.model_options.required_db_features
            or db_connection.features.supports_json_field
        )
        if supported:
            return []
        return [
            preflight.PreflightResult(
                fix=f"{db_connection.display_name} does not support JSONFields. Consider using a TextField with JSON serialization or upgrade to a database that supports JSON fields.",
                obj=self.model,
                id="fields.json_field_unsupported",
            )
        ]

    def deconstruct(self) -> tuple[str, str, list[Any], dict[str, Any]]:
        """Include any custom encoder/decoder in the field's deconstruction."""
        name, path, args, kwargs = super().deconstruct()
        for attr in ("encoder", "decoder"):
            codec = getattr(self, attr)
            if codec is not None:
                kwargs[attr] = codec
        return name, path, args, kwargs

    def from_db_value(
        self, value: Any, expression: Any, connection: BaseDatabaseWrapper
    ) -> Any:
        """Decode a database value into Python data."""
        # Some backends (SQLite at least) extract non-string values in their
        # SQL datatypes, so only string values need decoding for key
        # transforms.
        if value is None or (
            isinstance(expression, KeyTransform) and not isinstance(value, str)
        ):
            return value
        try:
            return json.loads(value, cls=self.decoder)
        except json.JSONDecodeError:
            # Not decodable as JSON; hand back the raw value unchanged.
            return value

    def get_internal_type(self) -> str:
        return "JSONField"

    def get_db_prep_value(
        self, value: Any, connection: BaseDatabaseWrapper, prepared: bool = False
    ) -> Any:
        if isinstance(value, expressions.Value):
            if isinstance(value.output_field, JSONField):
                # Unwrap Value(..., output_field=JSONField()) to its payload.
                value = value.value
            else:
                # Any other Value expression compiles to SQL on its own.
                return value
        elif hasattr(value, "as_sql"):
            return value
        return connection.ops.adapt_json_value(value, self.encoder)

    def get_db_prep_save(self, value: Any, connection: BaseDatabaseWrapper) -> Any:
        # None bypasses JSON adaptation entirely.
        return value if value is None else self.get_db_prep_value(value, connection)

    def get_transform(self, name: str) -> KeyTransformFactory | type[Transform]:
        # Any name that isn't a registered transform becomes a JSON key lookup.
        return super().get_transform(name) or KeyTransformFactory(name)

    def validate(self, value: Any, model_instance: Any) -> None:
        super().validate(value, model_instance)
        try:
            json.dumps(value, cls=self.encoder)
        except TypeError:
            # The value can't be serialized with the configured encoder.
            raise exceptions.ValidationError(
                self.error_messages["invalid"],
                code="invalid",
                params={"value": value},
            )

    def value_to_string(self, obj: Any) -> Any:
        return self.value_from_object(obj)
140
141
def compile_json_path(key_transforms: list[Any], include_root: bool = True) -> str:
    """Build a JSON path string (e.g. '$."a"[0]') from a list of keys.

    Keys that parse as integers become array indexes (``[n]``); everything
    else becomes a quoted object key (``."key"``). When ``include_root`` is
    false the leading ``$`` is omitted (useful for path suffixes).
    """
    segments = ["$"] if include_root else []
    for key in key_transforms:
        try:
            index = int(key)
        except ValueError:
            # Non-integer key: emit a JSON-quoted object member access.
            segments.append(f".{json.dumps(key)}")
        else:
            segments.append(f"[{index}]")
    return "".join(segments)
153
154
class DataContains(FieldGetDbPrepValueMixin, PostgresOperatorLookup):
    """``contains``: the stored JSON document contains the given fragment."""

    lookup_name = "contains"
    postgres_operator = "@>"

    def as_sql(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, tuple[Any, ...]]:
        if not connection.features.supports_json_field_contains:
            raise NotSupportedError(
                "contains lookup is not supported on this database backend."
            )
        lhs_sql, lhs_params = self.process_lhs(compiler, connection)
        rhs_sql, rhs_params = self.process_rhs(compiler, connection)
        # Column first, candidate fragment second.
        return f"JSON_CONTAINS({lhs_sql}, {rhs_sql})", (*lhs_params, *rhs_params)
170
171
class ContainedBy(FieldGetDbPrepValueMixin, PostgresOperatorLookup):
    """``contained_by``: the stored document is a subset of the given value."""

    lookup_name = "contained_by"
    postgres_operator = "<@"

    def as_sql(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, tuple[Any, ...]]:
        if not connection.features.supports_json_field_contains:
            raise NotSupportedError(
                "contained_by lookup is not supported on this database backend."
            )
        lhs_sql, lhs_params = self.process_lhs(compiler, connection)
        rhs_sql, rhs_params = self.process_rhs(compiler, connection)
        # Mirror image of DataContains: the rhs document is the container.
        return f"JSON_CONTAINS({rhs_sql}, {lhs_sql})", (*rhs_params, *lhs_params)
187
188
class HasKeyLookup(PostgresOperatorLookup):
    """Base class for the has_key / has_keys / has_any_keys lookups.

    Subclasses set ``logical_operator`` (" AND " / " OR ") to combine the
    per-key conditions; ``None`` means only a single key is tested.
    """

    logical_operator: str | None = None

    def compile_json_path_final_key(self, key_transform: Any) -> str:
        # Compile the final key without interpreting ints as array elements.
        return f".{json.dumps(key_transform)}"

    def as_sql(
        self,
        compiler: SQLCompiler,
        connection: BaseDatabaseWrapper,
        template: str | None = None,
    ) -> tuple[str, tuple[Any, ...]]:
        # ``template`` must hold a ``%s`` for the lhs expression and a ``%%s``
        # for each JSON-path parameter; the vendor entry points (as_mysql /
        # as_sqlite) supply it. Params are ordered lhs params first, then one
        # path string per tested key.
        # Process JSON path from the left-hand side.
        if isinstance(self.lhs, KeyTransform):
            lhs, lhs_params, lhs_key_transforms = self.lhs.preprocess_lhs(
                compiler, connection
            )
            lhs_json_path = compile_json_path(lhs_key_transforms)
        else:
            lhs, lhs_params = self.process_lhs(compiler, connection)
            lhs_json_path = "$"
        sql = template % lhs
        # Process JSON path from the right-hand side.
        rhs = self.rhs
        rhs_params = []
        if not isinstance(rhs, list | tuple):
            rhs = [rhs]
        for key in rhs:
            if isinstance(key, KeyTransform):
                *_, rhs_key_transforms = key.preprocess_lhs(compiler, connection)
            else:
                rhs_key_transforms = [key]
            # The last segment is compiled separately so subclasses can decide
            # whether an integer final key means an array index.
            *rhs_key_transforms, final_key = rhs_key_transforms
            rhs_json_path = compile_json_path(rhs_key_transforms, include_root=False)
            rhs_json_path += self.compile_json_path_final_key(final_key)
            rhs_params.append(lhs_json_path + rhs_json_path)
        # Add condition for each key.
        if self.logical_operator:
            sql = f"({self.logical_operator.join([sql] * len(rhs_params))})"
        return sql, tuple(lhs_params) + tuple(rhs_params)

    def as_mysql(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, tuple[Any, ...]]:
        return self.as_sql(
            compiler, connection, template="JSON_CONTAINS_PATH(%s, 'one', %%s)"
        )

    def as_postgresql(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, tuple[Any, ...]]:
        # Fold a rhs KeyTransform chain into the lhs so that only the final
        # key is tested by the postgres operator.
        # NOTE(review): this mutates self.lhs / self.rhs in place.
        if isinstance(self.rhs, KeyTransform):
            *_, rhs_key_transforms = self.rhs.preprocess_lhs(compiler, connection)
            for key in rhs_key_transforms[:-1]:
                self.lhs = KeyTransform(key, self.lhs)
            self.rhs = rhs_key_transforms[-1]
        return super().as_postgresql(compiler, connection)

    def as_sqlite(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, tuple[Any, ...]]:
        # JSON_TYPE() is NULL for a missing path, so non-NULL means present.
        return self.as_sql(
            compiler, connection, template="JSON_TYPE(%s, %%s) IS NOT NULL"
        )
254
255
# `field__has_key="key"`: the document contains the given single key/path.
class HasKey(HasKeyLookup):
    lookup_name = "has_key"
    postgres_operator = "?"
    prepare_rhs = False
260
261
class HasKeys(HasKeyLookup):
    """``has_keys=[...]``: every listed key must be present (AND-combined)."""

    lookup_name = "has_keys"
    postgres_operator = "?&"
    logical_operator = " AND "

    def get_prep_lookup(self) -> list[str]:
        # Keys are always compared as strings, whatever the caller passed.
        return list(map(str, self.rhs))
269
270
# `has_any_keys=[...]`: at least one listed key is present (OR-combined).
class HasAnyKeys(HasKeys):
    lookup_name = "has_any_keys"
    postgres_operator = "?|"
    logical_operator = " OR "
275
276
class HasKeyOrArrayIndex(HasKey):
    # Variant of HasKey whose final path segment may be an array index:
    # integer keys compile to [n] rather than a quoted object key.
    def compile_json_path_final_key(self, key_transform: Any) -> str:
        return compile_json_path([key_transform], include_root=False)
280
281
class CaseInsensitiveMixin:
    """
    Mixin to allow case-insensitive comparison of JSON values on MySQL.

    MySQL handles strings used in JSON context using the utf8mb4_bin
    collation. Because utf8mb4_bin is a binary collation, comparison of
    JSON values is case-sensitive; wrapping both sides in LOWER() restores
    case-insensitive matching.
    """

    def process_lhs(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, list[Any]]:
        sql, params = super().process_lhs(compiler, connection)  # type: ignore[misc]
        if connection.vendor == "mysql":
            sql = f"LOWER({sql})"
        return sql, params

    def process_rhs(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, list[Any]]:
        sql, params = super().process_rhs(compiler, connection)  # type: ignore[misc]
        if connection.vendor == "mysql":
            sql = f"LOWER({sql})"
        return sql, params
305
306
class JSONExact(lookups.Exact):
    """Exact match that understands JSON null and MySQL's JSON semantics."""

    can_use_none_as_rhs = True

    def process_rhs(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, list[Any]]:
        rhs, rhs_params = super().process_rhs(compiler, connection)
        # A None lookup value is compared against the JSON 'null' literal.
        if rhs == "%s" and rhs_params == [None]:
            rhs_params = ["null"]
        if connection.vendor == "mysql":
            # Route every placeholder through JSON_EXTRACT so the parameter
            # is interpreted as a JSON document.
            rhs %= tuple("JSON_EXTRACT(%s, '$')" for _ in rhs_params)
        return rhs, rhs_params
321
322
# Case-insensitive icontains over the JSON text; the mixin lowercases both
# sides on MySQL.
class JSONIContains(CaseInsensitiveMixin, lookups.IContains):
    pass
325
326
# Register the document-level lookups on the field itself.
for _json_lookup in (
    DataContains,
    ContainedBy,
    HasKey,
    HasKeys,
    HasAnyKeys,
    JSONExact,
    JSONIContains,
):
    JSONField.register_lookup(_json_lookup)
334
335
class KeyTransform(Transform):
    """Transform that indexes into a JSON document by object key or index."""

    postgres_operator = "->"
    postgres_nested_operator = "#>"

    def __init__(self, key_name: str, *args: Any, **kwargs: Any):
        super().__init__(*args, **kwargs)
        self.key_name = str(key_name)

    def preprocess_lhs(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, tuple[Any, ...], list[str]]:
        # Collapse a chain of nested KeyTransforms into a flat key list,
        # ordered from the outermost document key inward, and compile the
        # innermost (non-transform) expression.
        keys: list[str] = []
        expression: Any = self
        while isinstance(expression, KeyTransform):
            keys.append(expression.key_name)
            expression = expression.lhs
        keys.reverse()
        lhs_sql, params = compiler.compile(expression)
        return lhs_sql, params, keys

    def as_mysql(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, tuple[Any, ...]]:
        lhs_sql, params, keys = self.preprocess_lhs(compiler, connection)
        return f"JSON_EXTRACT({lhs_sql}, %s)", (*params, compile_json_path(keys))

    def as_postgresql(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, tuple[Any, ...]]:
        lhs_sql, params, keys = self.preprocess_lhs(compiler, connection)
        if len(keys) > 1:
            # Deep lookups use the nested operator with the keys passed as a
            # single path parameter.
            return (
                f"({lhs_sql} {self.postgres_nested_operator} %s)",
                (*params, keys),
            )
        try:
            lookup: Any = int(self.key_name)
        except ValueError:
            lookup = self.key_name
        return f"({lhs_sql} {self.postgres_operator} %s)", (*params, lookup)

    def as_sqlite(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, tuple[Any, ...]]:
        sqlite_connection = cast(SQLiteDatabaseWrapper, connection)
        lhs_sql, params, keys = self.preprocess_lhs(compiler, connection)
        json_path = compile_json_path(keys)
        datatype_values = ",".join(
            repr(datatype)
            for datatype in sqlite_connection.ops.jsonfield_datatype_values  # type: ignore[attr-defined]
        )
        sql = (
            f"(CASE WHEN JSON_TYPE({lhs_sql}, %s) IN ({datatype_values}) "
            f"THEN JSON_TYPE({lhs_sql}, %s) ELSE JSON_EXTRACT({lhs_sql}, %s) END)"
        )
        # The same (params, path) pair feeds all three placeholders.
        return sql, (*params, json_path) * 3
391
392
class KeyTextTransform(KeyTransform):
    """KeyTransform variant that yields the extracted value as text."""

    postgres_operator = "->>"
    postgres_nested_operator = "#>>"
    output_field = TextField()

    def as_mysql(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, tuple[Any, ...]]:
        mysql_connection = cast(MySQLDatabaseWrapper, connection)
        if not mysql_connection.mysql_is_mariadb:
            lhs_sql, params, keys = self.preprocess_lhs(compiler, connection)
            return f"({lhs_sql} ->> %s)", (*params, compile_json_path(keys))
        # MariaDB doesn't support -> and ->> operators (see MDEV-13594).
        sql, params = super().as_mysql(compiler, connection)
        return f"JSON_UNQUOTE({sql})", params

    @classmethod
    def from_lookup(cls, lookup: str) -> Any:
        """Build nested transforms from a "field__key1__key2" style string."""
        field_name, *keys = lookup.split(LOOKUP_SEP)
        if not keys:
            raise ValueError("Lookup must contain key or index transforms.")
        transform: Any = field_name
        for key in keys:
            transform = cls(key, transform)
        return transform
419
420
# Shorthand alias: KT("field__key__subkey") builds nested KeyTextTransforms.
KT = KeyTextTransform.from_lookup
422
423
class KeyTransformTextLookupMixin:
    """
    Mixin for combining with a lookup expecting a text lhs from a JSONField
    key lookup. On PostgreSQL, make use of the ->> operator instead of
    casting key values to text and performing the lookup on the resulting
    representation.
    """

    def __init__(self, key_transform: Any, *args: Any, **kwargs: Any):
        if not isinstance(key_transform, KeyTransform):
            raise TypeError(
                "Transform should be an instance of KeyTransform in order to "
                "use this lookup."
            )
        # Swap the JSON-valued transform for its text-returning counterpart.
        text_transform = KeyTextTransform(
            key_transform.key_name,
            *key_transform.source_expressions,
            **key_transform.extra,
        )
        super().__init__(text_transform, *args, **kwargs)  # type: ignore[misc]
444
445
class KeyTransformIsNull(lookups.IsNull):
    # key__isnull=False is the same as has_key='key'
    def as_sqlite(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, tuple[Any, ...]]:
        # isnull=True -> JSON_TYPE at the path is NULL (key missing);
        # isnull=False -> the path exists.
        if self.rhs:
            template = "JSON_TYPE(%s, %%s) IS NULL"
        else:
            template = "JSON_TYPE(%s, %%s) IS NOT NULL"
        has_key = HasKeyOrArrayIndex(self.lhs.lhs, self.lhs.key_name)
        return has_key.as_sql(compiler, connection, template=template)
459
460
class KeyTransformIn(lookups.In):
    # `key__in=[...]` lookup against values extracted by a KeyTransform.
    def resolve_expression_parameter(
        self,
        compiler: SQLCompiler,
        connection: BaseDatabaseWrapper,
        sql: str,
        param: Any,
    ) -> tuple[str, tuple[Any, ...]]:
        sql, params = super().resolve_expression_parameter(
            compiler,
            connection,
            sql,
            param,
        )
        # Without a native JSON column type, plain (non-expression) params
        # are routed through JSON_EXTRACT so they compare as JSON documents
        # against the extracted lhs value.
        if (
            not hasattr(param, "as_sql")
            and not connection.features.has_native_json_field
        ):
            if connection.vendor == "mysql":
                sql = "JSON_EXTRACT(%s, '$')"
            elif connection.vendor == "sqlite":
                sqlite_connection = cast(SQLiteDatabaseWrapper, connection)
                if params[0] not in sqlite_connection.ops.jsonfield_datatype_values:  # type: ignore[attr-defined]
                    sql = "JSON_EXTRACT(%s, '$')"
        if connection.vendor == "mysql":
            mysql_connection = cast(MySQLDatabaseWrapper, connection)
            if mysql_connection.mysql_is_mariadb:
                # MariaDB needs the extracted value unquoted for comparison.
                sql = f"JSON_UNQUOTE({sql})"
        return sql, params
490
491
class KeyTransformExact(JSONExact):
    # Exact match against a value extracted by a KeyTransform.
    def process_rhs(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, list[Any]]:
        if isinstance(self.rhs, KeyTransform):
            # Key-to-key comparison: call the implementation above
            # lookups.Exact in the MRO, bypassing both JSONExact's and
            # Exact's rhs processing.
            return super(lookups.Exact, self).process_rhs(compiler, connection)
        rhs, rhs_params = super().process_rhs(compiler, connection)
        if connection.vendor == "sqlite":
            sqlite_connection = cast(SQLiteDatabaseWrapper, connection)
            func = []
            for value in rhs_params:
                # Values in jsonfield_datatype_values are bound as-is; other
                # params go through JSON_EXTRACT to compare as JSON.
                if value in sqlite_connection.ops.jsonfield_datatype_values:  # type: ignore[attr-defined]
                    func.append("%s")
                else:
                    func.append("JSON_EXTRACT(%s, '$')")
            rhs %= tuple(func)
        return rhs, rhs_params
509
510
# Case-insensitive exact match on the text value extracted from a JSON key.
class KeyTransformIExact(
    CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IExact
):
    pass
515
516
# Case-insensitive substring match on the extracted text value.
class KeyTransformIContains(
    CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IContains
):
    pass
521
522
# Prefix match on the extracted text value.
class KeyTransformStartsWith(KeyTransformTextLookupMixin, lookups.StartsWith):
    pass
525
526
# Case-insensitive prefix match on the extracted text value.
class KeyTransformIStartsWith(
    CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IStartsWith
):
    pass
531
532
# Suffix match on the extracted text value.
class KeyTransformEndsWith(KeyTransformTextLookupMixin, lookups.EndsWith):
    pass
535
536
# Case-insensitive suffix match on the extracted text value.
class KeyTransformIEndsWith(
    CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IEndsWith
):
    pass
541
542
# Regular-expression match on the extracted text value.
class KeyTransformRegex(KeyTransformTextLookupMixin, lookups.Regex):
    pass
545
546
# Case-insensitive regular-expression match on the extracted text value.
class KeyTransformIRegex(
    CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IRegex
):
    pass
551
552
class KeyTransformNumericLookupMixin:
    """Decode JSON-encoded rhs params on backends without native JSON."""

    def process_rhs(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, list[Any]]:
        rhs, rhs_params = super().process_rhs(compiler, connection)  # type: ignore[misc]
        if not connection.features.has_native_json_field:
            # Params arrive JSON-encoded; load them back to Python values.
            rhs_params = list(map(json.loads, rhs_params))
        return rhs, rhs_params
561
562
# Numeric less-than comparison on the extracted value.
class KeyTransformLt(KeyTransformNumericLookupMixin, lookups.LessThan):
    pass
565
566
# Numeric less-than-or-equal comparison on the extracted value.
class KeyTransformLte(KeyTransformNumericLookupMixin, lookups.LessThanOrEqual):
    pass
569
570
# Numeric greater-than comparison on the extracted value.
class KeyTransformGt(KeyTransformNumericLookupMixin, lookups.GreaterThan):
    pass
573
574
# Numeric greater-than-or-equal comparison on the extracted value.
class KeyTransformGte(KeyTransformNumericLookupMixin, lookups.GreaterThanOrEqual):
    pass
577
578
# Register all key-level lookups on KeyTransform.
for _key_lookup in (
    KeyTransformIn,
    KeyTransformExact,
    KeyTransformIExact,
    KeyTransformIsNull,
    KeyTransformIContains,
    KeyTransformStartsWith,
    KeyTransformIStartsWith,
    KeyTransformEndsWith,
    KeyTransformIEndsWith,
    KeyTransformRegex,
    KeyTransformIRegex,
    KeyTransformLt,
    KeyTransformLte,
    KeyTransformGt,
    KeyTransformGte,
):
    KeyTransform.register_lookup(_key_lookup)
595
596
class KeyTransformFactory:
    """Callable that builds KeyTransform instances for a fixed key name."""

    def __init__(self, key_name: str):
        self.key_name = key_name

    def __call__(self, *args: Any, **kwargs: Any) -> KeyTransform:
        # Defer to KeyTransform with the captured key prepended.
        return KeyTransform(self.key_name, *args, **kwargs)